Changeset acb8788 in sasview
- Timestamp:
- May 21, 2008 11:11:48 AM (17 years ago)
- Branches:
- master, ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
- Children:
- c14c503
- Parents:
- 93c2ef2
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
park_integration/test/FittingModule.py
r4a0536a racb8788 1 1 #class Fitting 2 from sans.guitools.fittings import Parameter 3 import sans.guitools.fittings 2 from sans.guitools.plottables import Data1D 3 from Loader import Load 4 from scipy import optimize 4 5 class Fitting: 5 6 """ 6 7 Performs the Fit.he user determine what kind of data 7 8 """ 8 def __init__(self,data1,data2=None): 9 self.model = "" 10 self.data1 = data1 11 self.data2= data2 9 def __init__(self,data=[]): 10 #self.model is a list of all models to fit 11 self.model=[] 12 #the list of all data to fit 13 self.data = data 14 #dictionary of models parameters 15 self.parameters={} 12 16 self.contraint =None 17 self.residuals=[] 13 18 14 19 def fit_engine(self): … … 17 22 """ 18 23 return True 19 def fit(self ):24 def fit(self,pars, qmin=None, qmax=None): 20 25 """ 21 26 Do the fit 22 27 """ 23 24 #Display the fittings values 25 self.default_A = self.model.getParam('A') 26 self.default_B = self.model.getParam('B') 27 self.cstA = Parameter(self.model, 'A', self.default_A) 28 self.cstB = Parameter(self.model, 'B', self.default_B) 29 xmin= min(self.data1.x) 30 xmax= max(self.data1.x) 31 32 chisqr, out, cov = sans.guitools.fittings.sansfit(self.model, 33 [self.cstA, self.cstB],self.data1.x, self.data1.y,self.data1.dy,xmin,xmax) 34 28 29 # Do the fit with 2 data set and one model 30 numberData= len(self.data) 31 numberModel= len(self.model) 32 if numberData==1 and numberModel==1: 33 if qmin==None: 34 xmin= min(self.data[0].x) 35 if qmax==None: 36 xmax= max(self.data[0].x) 37 38 #chisqr, out, cov = fitHelper(self.model[0],self.data[0],pars,xmin,xmax) 39 chisqr, out, cov =fitHelper(self.model[0], pars, self.data[0].x, 40 self.data[0].y, self.data[0].dy ,xmin,xmax) 41 else:# More than one data to fit with one model 42 xtemp=[] 43 ytemp=[] 44 dytemp=[] 45 for data in self.data: 46 for i in range(len(data.x)): 47 if not data.x[i] in xtemp: 48 xtemp.append(data.x[i]) 49 50 if not data.y[i] in ytemp: 51 ytemp.append(data.y[i]) 52 53 if not data.dy[i] in dytemp: 54 dytemp.append(data.dy[i]) 55 if qmin==None: 56 xmin= min(xtemp) 57 if qmax==None: 58 xmax= max(xtemp) 59 #chisqr, out, cov = fitHelper(self.model[0], 60 #temp,pars,min(temp.x),max(temp.x)) 61 chisqr, out, cov =fitHelper(self.model[0], pars, xtemp, 62 ytemp, dytemp ,xmin,xmax) 35 63 return chisqr, out, cov 36 64 37 65 def set_model(self,model): 38 66 """ Set model """ 39 self.model = model 40 41 42 def set_data(self,x,y,dx,dy): 43 """ Receive values from Loader class and set plottable variables""" 44 self.data1.x = x 45 self.data1.y = y 46 self.data1.dx= dx 47 self.data1.dy= dy 67 self.model.append(model) 68 69 def set_data(self,data): 70 """ Receive plottable and create a list of data to fit""" 71 self.data.append(data) 72 48 73 def get_data(self): 49 """ return data"""50 return self.data 174 """ return list of data""" 75 return self.data 51 76 52 77 def add_contraint(self, contraint): 53 78 """ User specify contraint to fit """ 54 79 self.contraint = str(contraint) 80 55 81 def get_contraint(self): 56 82 """ return the contraint value """ 57 83 return self.contraint 58 59 84 85 def get_residuals(model,data,qmin=None,qmax=None): 86 """ 87 Calculates the vector of residuals for each point 88 in y for a given set of input parameters. 89 @param params: list of parameter values 90 @return: vector of residuals 91 """ 92 residuals = [] 93 94 for j in range(len(data.x)): 95 if data.x[j]> qmin and data.x[j]< qmax: 96 residuals.append( ( data.y[j] - model.runXY(data.x[j]) ) / data.dy[j]) 97 98 return residuals 99 100 101 def chi2(params): 102 """ 103 Calculates chi^2 104 @param params: list of parameter values 105 @return: chi^2 106 """ 107 sum = 0 108 res = get_residuals(params) 109 for item in res: 110 sum += item*item 111 return sum 112 113 114 def residual(self): 115 return self.residuals 116 117 def fitHelper(model,data,pars,qmin=None,qmax=None): 118 """ Do the actual fitting""" 119 120 p = [param() for param in pars] 121 out, cov_x, info, mesg, success = optimize.leastsq(get_residuals, p, full_output=1, warning=True) 122 print info, mesg, success 123 # Calculate chi squared 124 if len(pars)>1: 125 chisqr = self.chi2(out) 126 elif len(pars)==1: 127 chisqr = self.chi2([out]) 128 129 return chisqr, out, cov_x 130 131 class Parameter: 132 """ 133 Class to handle model parameters 134 """ 135 def __init__(self, model, name, value=None): 136 self.model = model 137 self.name = name 138 if not value==None: 139 self.model.setParam(self.name, value) 140 141 def set(self, value): 142 """ 143 Set the value of the parameter 144 """ 145 self.model.setParam(self.name, value) 146 147 def __call__(self): 148 """ 149 Return the current value of the parameter 150 """ 151 return self.model.getParam(self.name) 152 153 def fitHelper(model, pars, x, y, err_y ,qmin=None, qmax=None): 154 """ 155 Fit function 156 @param model: sans model object 157 @param pars: list of parameters 158 @param x: vector of x data 159 @param y: vector of y data 160 @param err_y: vector of y errors 161 """ 162 def f(params): 163 """ 164 Calculates the vector of residuals for each point 165 in y for a given set of input parameters. 166 @param params: list of parameter values 167 @return: vector of residuals 168 """ 169 i = 0 170 for p in pars: 171 p.set(params[i]) 172 i += 1 173 174 residuals = [] 175 for j in range(len(x)): 176 if x[j]>qmin and x[j]<qmax: 177 residuals.append( ( y[j] - model.runXY(x[j]) ) / err_y[j] ) 178 179 return residuals 180 181 def chi2(params): 182 """ 183 Calculates chi^2 184 @param params: list of parameter values 185 @return: chi^2 186 """ 187 sum = 0 188 res = f(params) 189 for item in res: 190 sum += item*item 191 return sum 192 193 p = [param() for param in pars] 194 out, cov_x, info, mesg, success = optimize.leastsq(f, p, full_output=1, warning=True) 195 print info, mesg, success 196 # Calculate chi squared 197 if len(pars)>1: 198 chisqr = chi2(out) 199 elif len(pars)==1: 200 chisqr = chi2([out]) 201 202 return chisqr, out, cov_x 203 204 205 if __name__ == "__main__": 206 load= Load() 207 208 # test fit one data set one model 209 load.set_filename("testdata_line.txt") 210 load.set_values() 211 data1 = Data1D(x=[], y=[], dx=None,dy=None) 212 data1.name = "data1" 213 load.load_data(data1) 214 Fit =Fitting() 215 Fit.set_data(data1) 216 from sans.guitools.LineModel import LineModel 217 model = LineModel() 218 Fit.set_model(model) 219 220 default_A = model.getParam('A') 221 default_B = model.getParam('B') 222 cstA = Parameter(model, 'A', default_A) 223 cstB = Parameter(model, 'B', default_B) 224 225 chisqr, out, cov=Fit.fit([cstA,cstB],None,None) 226 print"fit only one data",chisqr, out, cov 227 228 # test fit with 2 data and one model 229 load.set_filename("testdata1.txt") 230 load.set_values() 231 data2 = Data1D(x=[], y=[], dx=None,dy=None) 232 data2.name = "data2" 233 234 load.load_data(data2) 235 Fit.set_data(data2) 236 237 load.set_filename("testdata2.txt") 238 load.set_values() 239 data3 = Data1D(x=[], y=[], dx=None,dy=None) 240 data3.name = "data2" 241 load.load_data(data3) 242 Fit.set_data(data3) 243 chisqr, out, cov=Fit.fit([cstA,cstB],None,None) 244 print"fit two data",chisqr, out, cov
Note: See TracChangeset
for help on using the changeset viewer.