Changeset 95d58d3 in sasview for src/sans/fit/BumpsFitting.py
- Timestamp:
- Apr 10, 2014 8:05:28 PM (10 years ago)
- Branches:
- master, ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
- Children:
- 90f49a8
- Parents:
- 6fe5100
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
src/sans/fit/BumpsFitting.py
r6fe5100 r95d58d3 3 3 """ 4 4 import sys 5 import copy6 5 7 6 import numpy … … 13 12 from sans.fit.AbstractFitEngine import FResult 14 13 15 class Sa nsAssembly(object):14 class SasProblem(object): 16 15 """ 17 Sans Assembly class a class wrapper to be call in optimizer.leastsq method16 Wrap the SAS model in a form that can be understood by bumps. 18 17 """ 19 def __init__(self, param list, model=None, data=None, fitresult=None,18 def __init__(self, param_list, model=None, data=None, fitresult=None, 20 19 handler=None, curr_thread=None, msg_q=None): 21 20 """ … … 25 24 self.model = model 26 25 self.data = data 27 self.param list = paramlist26 self.param_list = param_list 28 27 self.msg_q = msg_q 29 28 self.curr_thread = curr_thread … … 37 36 @property 38 37 def dof(self): 39 return self.data.num_points - len(self.param list)38 return self.data.num_points - len(self.param_list) 40 39 41 40 def summarize(self): 42 return "summarize" 43 44 def nllf(self, pvec=None): 45 residuals = self.residuals(pvec) 41 """ 42 Return a stylized list of parameter names and values with range bars 43 suitable for printing. 44 """ 45 output = [] 46 bounds = self.bounds() 47 for i,p in enumerate(self.getp()): 48 name = self.param_list[i] 49 low,high = bounds[:,i] 50 range = ",".join((("[%g"%low if numpy.isfinite(low) else "(-inf"), 51 ("%g]"%high if numpy.isfinite(high) else "inf)"))) 52 if not numpy.isfinite(p): 53 bar = "*invalid* " 54 else: 55 bar = ['.']*10 56 if numpy.isfinite(high-low): 57 position = int(9.999999999 * float(p-low)/float(high-low)) 58 if position < 0: bar[0] = '<' 59 elif position > 9: bar[9] = '>' 60 else: bar[position] = '|' 61 bar = "".join(bar) 62 output.append("%40s %s %10g in %s"%(name,bar,p,range)) 63 return "\n".join(output) 64 65 def nllf(self, p=None): 66 residuals = self.residuals(p) 46 67 return 0.5*numpy.sum(residuals**2) 47 68 48 def setp(self, params): 49 self.model.set_params(self.paramlist, params) 69 def setp(self, p): 70 for k,v in zip(self.param_list, p): 71 self.model.setParam(k,v) 72 #self.model.set_params(self.param_list, params) 50 73 51 74 def getp(self): 52 return numpy.asarray(self.model.get_params(self.paramlist)) 75 return numpy.array([self.model.getParam(k) for k in self.param_list]) 76 #return numpy.asarray(self.model.get_params(self.param_list)) 53 77 54 78 def bounds(self): 55 return numpy.array([self._getrange(p) for p in self.param list]).T79 return numpy.array([self._getrange(p) for p in self.param_list]).T 56 80 57 81 def labels(self): 58 return self.param list82 return self.param_list 59 83 60 84 def _getrange(self, p): … … 63 87 return the range of parameter 64 88 """ 65 lo, hi = self.model. model.details[p][1:3]89 lo, hi = self.model.details[p][1:3] 66 90 if lo is None: lo = -numpy.inf 67 91 if hi is None: hi = numpy.inf … … 69 93 70 94 def randomize(self, n): 71 p vec= self.getp()95 p = self.getp() 72 96 # since randn is symmetric and random, doesn't matter 73 97 # point value is negative. 74 98 # TODO: throw in bounds checking! 75 return numpy.random.randn(n, len(self.param list))*pvec + pvec99 return numpy.random.randn(n, len(self.param_list))*p + p 76 100 77 101 def chisq(self): … … 84 108 85 109 """ 86 total = 0 87 for item in self.res: 88 total += item * item 89 if len(self.res) == 0: 90 return None 91 return total / len(self.res) 110 return numpy.sum(self.res**2)/self.dof 92 111 93 112 def residuals(self, params=None): … … 99 118 #import thread 100 119 #print "params", params 101 self.res, self.theory = self.data.residuals(self.model.eval) 102 120 self.res, self.theory = self.data.residuals(self.model.evalDistribution) 121 122 # TODO: this belongs in monitor not residuals calculation 103 123 if self.fitresult is not None: 104 self.fitresult.set_model(model=self.model)124 #self.fitresult.set_model(model=self.model) 105 125 self.fitresult.residuals = self.res+0 106 126 self.fitresult.iterations += 1 … … 109 129 #fitness = self.chisq(params=params) 110 130 fitness = self.chisq() 111 self.fitresult.p vec= params131 self.fitresult.p = params 112 132 self.fitresult.set_fitness(fitness=fitness) 113 133 if self.msg_q is not None: … … 131 151 __call__ = residuals 132 152 133 def check_param_range(self):153 def _DEAD_check_param_range(self): 134 154 """ 135 155 Check the lower and upper bound of the parameter value … … 142 162 is_outofbound = False 143 163 # loop through the fit parameters 144 model = self.model .model145 for p in self.param list:164 model = self.model 165 for p in self.param_list: 146 166 value = model.getParam(p) 147 167 low,high = model.details[p][1:3] … … 196 216 raise RuntimeError, msg 197 217 elif len(fitproblem) == 0 : 198 raise RuntimeError, "No Assembly scheduled for Scipyfitting."218 raise RuntimeError, "No problem scheduled for fitting." 199 219 model = fitproblem[0].get_model() 200 220 if reset_flag: … … 203 223 ind = fitproblem[0].pars.index(name) 204 224 model.setParam(name, fitproblem[0].vals[ind]) 205 listdata = []206 225 listdata = fitproblem[0].get_data() 207 226 # Concatenate dList set (contains one or more data)before fitting … … 209 228 210 229 self.curr_thread = curr_thread 211 ftol = ftol212 230 213 231 result = FResult(model=model, data=data, param_list=self.param_list) … … 217 235 if handler is not None: 218 236 handler.set_result(result=result) 219 functor = SansAssembly(paramlist=self.param_list,220 model=model,221 222 223 224 225 237 problem = SasProblem(param_list=self.param_list, 238 model=model.model, 239 data=data, 240 handler=handler, 241 fitresult=result, 242 curr_thread=curr_thread, 243 msg_q=msg_q) 226 244 try: 227 run_bumps(functor, result) 245 #run_bumps(problem, result, ftol) 246 run_scipy(problem, result, ftol) 228 247 except: 229 248 if hasattr(sys, 'last_type') and sys.last_type == KeyboardInterrupt: … … 245 264 return [result] 246 265 247 def run_bumps(problem, result ):266 def run_bumps(problem, result, ftol): 248 267 fitopts = fitters.FIT_OPTIONS[fitters.FIT_DEFAULT] 249 fitdriver = fitters.FitDriver(fitopts.fitclass, problem=problem, 250 abort_test=lambda: False, **fitopts.options) 268 fitclass = fitopts.fitclass 269 options = fitopts.options.copy() 270 options['ftol'] = ftol 271 fitdriver = fitters.FitDriver(fitclass, problem=problem, 272 abort_test=lambda: False, **options) 251 273 mapper = SerialMapper 252 274 fitdriver.mapper = mapper.start_mapper(problem, None) … … 256 278 import traceback; traceback.print_exc() 257 279 raise 258 mapper.stop_mapper(fitdriver.mapper) 259 fitdriver.show() 260 #fitdriver.plot() 261 result.fitness = fbest * 2. / len(result.pars) 262 result.stderr = numpy.ones(len(result.pars)) 263 result.pvec = best 280 finally: 281 mapper.stop_mapper(fitdriver.mapper) 282 #print "best,fbest",best,fbest,problem.dof 283 result.fitness = 2*fbest/problem.dof 284 #print "fitness",result.fitness 285 result.stderr = fitdriver.stderr() 286 result.pvec = best 287 # TODO: track success better 264 288 result.success = True 265 289 result.theory = problem.theory 266 290 267 def run_scipy(model, result ):291 def run_scipy(model, result, ftol): 268 292 # This import must be here; otherwise it will be confused when more 269 293 # than one thread exist. 270 294 from scipy import optimize 271 295 272 out, cov_x, _, mesg, success = optimize.leastsq( functor,273 model.get _params(self.param_list),296 out, cov_x, _, mesg, success = optimize.leastsq(model.residuals, 297 model.getp(), 274 298 ftol=ftol, 275 299 full_output=1) … … 278 302 else: 279 303 stderr = [] 280 result.fitness = functor.chisqr()304 result.fitness = model.chisq() 281 305 result.stderr = stderr 282 306 result.pvec = out 283 307 result.success = success 284 result.theory = functor.theory285 308 result.theory = model.theory 309
Note: See TracChangeset
for help on using the changeset viewer.