Changeset fbc51ef in sasview
- Timestamp:
- Jun 19, 2008 4:08:51 PM (16 years ago)
- Branches:
- master, ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
- Children:
- 7681bac
- Parents:
- 3701620
- File:
-
- 1 moved
Legend:
- Unmodified
- Added
- Removed
-
park_integration/test/ParkFitting.py
r197ea24 rfbc51ef 1 1 #class Fitting 2 import time 3 4 import numpy 5 import park 6 from scipy import optimize 7 from park import fit,fitresult 8 from park import assembly 9 2 10 from sans.guitools.plottables import Data1D 11 #from sans.guitools import plottables 3 12 from Loader import Load 4 from scipy import optimize 5 6 13 14 class SansParameter(park.Parameter): 15 """ 16 SANS model parameters for use in the PARK fitting service. 17 The parameter attribute value is redirected to the underlying 18 parameter value in the SANS model. 19 """ 20 def __init__(self, name, model): 21 self._model, self._name = model,name 22 def _getvalue(self): return self._model.getParam(self.name) 23 def _setvalue(self,value): self._model.setParam(self.name, value) 24 value = property(_getvalue,_setvalue) 25 def _getrange(self): 26 lo,hi = self._model.details[self.name][1:] 27 if lo is None: lo = -numpy.inf 28 if hi is None: hi = numpy.inf 29 return lo,hi 30 def _setrange(self,r): 31 self._model.details[self.name][1:] = r 32 range = property(_getrange,_setrange) 33 34 class Model(object): 35 """ 36 PARK wrapper for SANS models. 37 """ 38 def __init__(self, sans_model): 39 self.model = sans_model 40 sansp = sans_model.getParamList() 41 parkp = [SansParameter(p,sans_model) for p in sansp] 42 self.parameterset = park.ParameterSet(sans_model.name,pars=parkp) 43 def eval(self,x): 44 return self.model.run(x) 45 46 class Data(object): 47 """ Wrapper class for SANS data """ 48 def __init__(self, sans_data): 49 self.x= sans_data.x 50 self.y= sans_data.y 51 self.dx= sans_data.dx 52 self.dy= sans_data.dy 53 self.qmin=None 54 self.qmax=None 55 56 def setFitRange(self,mini=None,maxi=None): 57 """ to set the fit range""" 58 self.qmin=mini 59 self.qmax=maxi 60 61 def residuals(self, fn): 62 63 x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)] 64 if self.qmin==None and self.qmax==None: 65 return (y - fn(x))/dy 66 67 else: 68 idx = x>=self.qmin & x <= self.qmax 69 return (y[idx] - fn(x[idx]))/dy[idx] 70 71 72 def residuals_deriv(self, model, pars=[]): 73 """ Return residual derivatives .in this case just return empty array""" 74 return [] 75 7 76 class FitArrange: 8 77 def __init__(self): … … 43 112 self.dList.remove(data) 44 113 45 class Fitting:114 class ParkFit: 46 115 """ 47 116 Performs the Fit.he user determine what kind of data … … 55 124 self.fitType =None 56 125 57 def fit_engine(self,word):126 def createProblem(self,pars={}): 58 127 """ 59 128 Check the contraint value and specify what kind of fit to use 60 """ 61 self.fitType = word 62 return True 129 return (M1,D1) 130 """ 131 mylist=[] 132 for k,value in self.fitArrangeList.iteritems(): 133 couple=() 134 model=value.get_model() 135 parameters= self.set_param(model, pars) 136 model = Model(model) 137 #print "model created",model.parameterset[0].value,model.parameterset[1].value 138 # Make all parameters fitting parameters 139 for p in model.parameterset: 140 p.set([-numpy.inf,numpy.inf]) 141 #p.set([-10,10]) 142 Ldata=value.get_data() 143 data=self._concatenateData(Ldata) 144 #print "this data",data 145 #print "data.residuals in createProblem",Ldata[0].residuals 146 #print "data.residuals in createProblem",data.residuals 147 #couple1=(model,Ldata[0]) 148 #mylist.append(couple1) 149 couple=(model,data) 150 mylist.append(couple) 151 #print mylist 152 return mylist 153 #return model,data 63 154 64 155 def fit(self,pars, qmin=None, qmax=None): … … 66 157 Do the fit 67 158 """ 68 #for item in self.fitArrangeList.: 69 70 fitproblem=self.fitArrangeList.values()[0] 71 listdata=[] 72 model = fitproblem.get_model() 73 listdata = fitproblem.get_data() 74 75 parameters = self.set_param(model,pars) 159 160 modelList=self.createProblem(pars) 161 #model,data=self.createProblem() 162 #fitness=assembly.Fitness(model,data) 163 164 problem = park.Assembly(modelList) 165 #print "problem :",problem[0].parameterset,problem[0].parameterset.fitted 166 #problem[0].parameterset['A'].set([0,1000]) 167 #print "problem :",problem[0].parameterset,problem[0].parameterset.fitted 168 fit.fit(problem, handler= fitresult.ConsoleUpdate(improvement_delta=0.1)) 169 #return fit.fit(problem) 170 #fit.fit(problem, handler= fitresult.ConsoleUpdate(improvement_delta=0.1)) 171 172 173 def set_model(self,model,Uid): 174 """ Set model """ 175 176 if self.fitArrangeList.has_key(Uid): 177 self.fitArrangeList[Uid].set_model(model) 178 else: 179 fitproblem= FitArrange() 180 fitproblem.set_model(model) 181 self.fitArrangeList[Uid]=fitproblem 182 183 def set_data(self,data,Uid): 184 """ Receive plottable and create a list of data to fit""" 185 data=Data(data) 186 if self.fitArrangeList.has_key(Uid): 187 self.fitArrangeList[Uid].add_data(data) 188 else: 189 fitproblem= FitArrange() 190 fitproblem.add_data(data) 191 self.fitArrangeList[Uid]=fitproblem 192 193 def get_model(self,Uid): 194 """ return list of data""" 195 return self.fitArrangeList[Uid] 196 197 def set_param(self,model, pars): 198 """ Recieve a dictionary of parameter and save it """ 199 parameters=[] 200 if model==None: 201 raise ValueError, "Cannot set parameters for empty model" 202 else: 203 #for key ,value in pars: 204 for key, value in pars.iteritems(): 205 param = Parameter(model, key, value) 206 parameters.append(param) 207 return parameters 208 209 def add_constraint(self, constraint): 210 """ User specify contraint to fit """ 211 self.constraint = str(constraint) 212 213 def get_constraint(self): 214 """ return the contraint value """ 215 return self.constraint 216 217 def set_constraint(self,constraint): 218 """ 219 receive a string as a constraint 220 @param constraint: a string used to constraint some parameters to get a 221 specific value 222 """ 223 self.constraint= constraint 224 def _concatenateData(self, listdata=[]): 225 """ concatenate each fields of all Data contains ins listdata 226 return data 227 """ 76 228 if listdata==[]: 77 229 raise ValueError, " data list missing" 78 230 else: 79 # Do the fit with more than one data set and one model80 231 xtemp=[] 81 232 ytemp=[] 82 233 dytemp=[] 83 234 resid=[] 235 resid_deriv=[] 236 84 237 for data in listdata: 85 238 for i in range(len(data.x)): … … 92 245 if not data.dy[i] in dytemp: 93 246 dytemp.append(data.dy[i]) 94 if qmin==None: 95 qmin= min(xtemp) 96 if qmax==None: 97 qmax= max(xtemp) 98 chisqr, out, cov = fitHelper(model,parameters, xtemp,ytemp, dytemp ,qmin,qmax) 99 return chisqr, out, cov 100 101 def set_model(self,model,Uid): 102 """ Set model """ 103 if self.fitArrangeList.has_key(Uid): 104 self.fitArrangeList[Uid].set_model(model) 105 else: 106 fitproblem= FitArrange() 107 fitproblem.set_model(model) 108 self.fitArrangeList[Uid]=fitproblem 109 110 def set_data(self,data,Uid): 111 """ Receive plottable and create a list of data to fit""" 112 113 if self.fitArrangeList.has_key(Uid): 114 self.fitArrangeList[Uid].add_data(data) 115 else: 116 fitproblem= FitArrange() 117 fitproblem.add_data(data) 118 self.fitArrangeList[Uid]=fitproblem 119 120 def get_model(self,Uid): 121 """ return list of data""" 122 return self.fitArrangeList[Uid] 123 124 def set_param(self,model, pars): 125 """ Recieve a dictionary of parameter and save it """ 126 parameters=[] 127 if model==None: 128 raise ValueError, "Cannot set parameters for empty model" 129 else: 130 #for key ,value in pars: 131 for key, value in pars.iteritems(): 132 param = Parameter(model, key, value) 133 parameters.append(param) 134 return parameters 135 136 def add_constraint(self, constraint): 137 """ User specify contraint to fit """ 138 self.constraint = str(constraint) 139 140 def get_constraint(self): 141 """ return the contraint value """ 142 return self.constraint 143 144 def set_constraint(self,constraint): 145 """ 146 receive a string as a constraint 147 @param constraint: a string used to constraint some parameters to get a 148 specific value 149 """ 150 self.constraint= constraint 151 152 153 154 247 248 249 newplottable= Data1D(xtemp,ytemp,None,dytemp) 250 newdata=Data(newplottable) 251 252 #print "this is new data",newdata.dy 253 return newdata 155 254 class Parameter: 156 255 """ … … 174 273 """ 175 274 return self.model.getParam(self.name) 176 177 class FitHelper: 178 179 def __init__(self,model, pars, x, y, err_y ,qmin=None, qmax=None): 180 self.x = x 181 self.y = y 182 self.model = model 183 self.err_y = err_y 184 self.qmin = qmin 185 self.qmax= qmax 186 self.pars = pars 187 188 def __call__(self, params): 189 i = 0 190 for p in self.pars: 191 p.set(params[i]) 192 i += 1 193 194 residuals = [] 195 for j in range(len(self.x)): 196 if self.x[j]>self.qmin and self.x[j]<self.qmax: 197 residuals.append( ( self.y[j] - self.model.runXY(self.x[j]) ) / self.err_y[j] ) 198 199 return residuals 200 201 202 203 def fitHelper(model, pars, x, y, err_y ,qmin=None, qmax=None): 204 """ 205 Fit function 206 @param model: sans model object 207 @param pars: list of parameters 208 @param x: vector of x data 209 @param y: vector of y data 210 @param err_y: vector of y errors 211 """ 212 213 f = FitHelper(model, pars, x, y, err_y ,qmin, qmax) 214 215 def ff(params): 216 """ 217 Calculates the vector of residuals for each point 218 in y for a given set of input parameters. 219 @param params: list of parameter values 220 @return: vector of residuals 221 """ 222 i = 0 223 for p in pars: 224 p.set(params[i]) 225 i += 1 226 227 residuals = [] 228 for j in range(len(x)): 229 if x[j]>qmin and x[j]<qmax: 230 residuals.append( ( y[j] - model.runXY(x[j]) ) / err_y[j] ) 231 232 return residuals 233 234 def chi2(params): 235 """ 236 Calculates chi^2 237 @param params: list of parameter values 238 @return: chi^2 239 """ 240 sum = 0 241 res = f(params) 242 for item in res: 243 sum += item*item 244 return sum 245 246 p = [param() for param in pars] 247 out, cov_x, info, mesg, success = optimize.leastsq(f, p, full_output=1, warning=True) 248 print info, mesg, success 249 # Calculate chi squared 250 if len(pars)>1: 251 chisqr = chi2(out) 252 elif len(pars)==1: 253 chisqr = chi2([out]) 254 255 return chisqr, out, cov_x 275 256 276 257 277 … … 265 285 data1.name = "data1" 266 286 load.load_data(data1) 267 Fit =Fitting()268 269 from LineModel import LineModel287 fitter =ParkFit() 288 289 from sans.guitools.LineModel import LineModel 270 290 model = LineModel() 271 Fit.set_model(model,1) 272 Fit.set_data(data1,1) 273 274 chisqr, out, cov=Fit.fit({'A':2,'B':1},None,None) 275 print"fit only one data",chisqr, out, cov 276 277 # test fit with 2 data and one model 278 Fit =Fitting() 279 Fit.set_model(model,2 ) 280 load.set_filename("testdata1.txt") 281 load.set_values() 282 data2 = Data1D(x=[], y=[], dx=None,dy=None) 283 data2.name = "data2" 284 285 load.load_data(data2) 286 Fit.set_data(data2,2) 287 288 load.set_filename("testdata2.txt") 289 load.set_values() 290 data3 = Data1D(x=[], y=[], dx=None,dy=None) 291 data3.name = "data2" 292 load.load_data(data3) 293 Fit.set_data(data3,2) 294 chisqr, out, cov=Fit.fit({'A':2,'B':1},None,None) 295 print"fit two data",chisqr, out, cov 296 291 fitter.set_model(model,1) 292 fitter.set_data(data1,1) 293 294 print"PARK fit result \n",fitter.fit({'A':2,'B':1},None,None) 295 296 297 298
Note: See TracChangeset
for help on using the changeset viewer.