source: sasview/park_integration/AbstractFitEngine.py @ b2f4f83

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since b2f4f83 was 920a6e5, checked in by Jae Cho <jhjcho@…>, 15 years ago

fixed park fit w/ polydispersion

  • Property mode set to 100644
File size: 21.4 KB
Line 
1import logging, sys
2import park,numpy,math, copy
3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
18         
19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
25   
26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
40        #if not  self.name in self._model.getDispParamList():
41        lo,hi = self._model.details[self.name][1:]
42        if lo is None: lo = -numpy.inf
43        if hi is None: hi = numpy.inf
44        #else:
45            #lo,hi = self._model.details[self.name][1:]
46            #if lo is None: lo = -numpy.inf
47            #if hi is None: hi = numpy.inf
48        if lo >= hi:
49            raise ValueError,"wrong fit range for parameters"
50       
51        return lo,hi
52   
53    def _setrange(self,r):
54        """
55            override _setrange of park parameter
56            @param r: the value of the range to set
57        """
58        self._model.details[self.name][1:] = r
59    range = property(_getrange,_setrange)
60   
61class Model(park.Model):
62    """
63        PARK wrapper for SANS models.
64    """
65    def __init__(self, sans_model, **kw):
66        """
67            @param sans_model: the sans model to wrap using park interface
68        """
69        park.Model.__init__(self, **kw)
70        self.model = sans_model
71        self.name = sans_model.name
72        #list of parameters names
73        self.sansp = sans_model.getParamList()
74        #list of park parameter
75        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
76        #list of parameterset
77        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
78        self.pars=[]
79 
80 
81    def getParams(self,fitparams):
82        """
83            return a list of value of paramter to fit
84            @param fitparams: list of paramaters name to fit
85        """
86        list=[]
87        self.pars=[]
88        self.pars=fitparams
89        for item in fitparams:
90            for element in self.parkp:
91                 if element.name ==str(item):
92                     list.append(element.value)
93        return list
94   
95   
96    def setParams(self,paramlist, params):
97        """
98            Set value for parameters to fit
99            @param params: list of value for parameters to fit
100        """
101        try:
102            for i in range(len(self.parkp)):
103                for j in range(len(paramlist)):
104                    if self.parkp[i].name==paramlist[j]:
105                        self.parkp[i].value = params[j]
106                        self.model.setParam(self.parkp[i].name,params[j])
107        except:
108            raise
109 
110    def eval(self,x):
111        """
112            override eval method of park model.
113            @param x: the x value used to compute a function
114        """
115        try:
116                return self.model.evalDistribution(x)
117        except:
118                raise
119
120   
121class FitData1D(object):
122    """ Wrapper class  for SANS data """
123    def __init__(self,sans_data1d, smearer=None):
124        """
125            Data can be initital with a data (sans plottable)
126            or with vectors.
127           
128            self.smearer is an object of class QSmearer or SlitSmearer
129            that will smear the theory data (slit smearing or resolution
130            smearing) when set.
131           
132            The proper way to set the smearing object would be to
133            do the following:
134           
135            from DataLoader.qsmearing import smear_selection
136            fitdata1d = FitData1D(some_data)
137            fitdata1d.smearer = smear_selection(some_data)
138           
139            Note that some_data _HAS_ to be of class DataLoader.data_info.Data1D
140           
141            Setting it back to None will turn smearing off.
142           
143        """
144       
145        self.smearer = smearer
146     
147        # Initialize from Data1D object
148        self.data=sans_data1d
149        self.x= numpy.array(sans_data1d.x)
150        self.y= numpy.array(sans_data1d.y)
151        self.dx= sans_data1d.dx
152        if sans_data1d.dy ==None or sans_data1d.dy==[]:
153            self.dy= numpy.zeros(len(y)) 
154        else:
155            self.dy= numpy.asarray(sans_data1d.dy)
156     
157        # For fitting purposes, replace zero errors by 1
158        #TODO: check validity for the rare case where only
159        # a few points have zero errors
160        self.dy[self.dy==0]=1
161       
162        ## Min Q-value
163        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
164        if min (self.data.x) ==0.0 and self.data.x[0]==0 and not numpy.isfinite(self.data.y[0]):
165            self.qmin = min(self.data.x[self.data.x!=0])
166        else:                             
167            self.qmin= min (self.data.x)
168        ## Max Q-value
169        self.qmax= max (self.data.x)
170       
171        # Range used for input to smearing
172        self._qmin_unsmeared = self.qmin
173        self._qmax_unsmeared = self.qmax
174        # Identify the bin range for the unsmeared and smeared spaces
175        self.idx = (self.x>=self.qmin) & (self.x <= self.qmax)
176        self.idx_unsmeared = (self.x>=self._qmin_unsmeared) & (self.x <= self._qmax_unsmeared)
177 
178       
179       
180    def setFitRange(self,qmin=None,qmax=None):
181        """ to set the fit range"""
182        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
183        #ToDo: Fix this.
184        if qmin==0.0 and not numpy.isfinite(self.data.y[qmin]):
185            self.qmin = min(self.data.x[self.data.x!=0])
186        elif qmin!=None:                       
187            self.qmin = qmin           
188
189        if qmax !=None:
190            self.qmax = qmax
191           
192        # Range used for input to smearing
193        self._qmin_unsmeared = self.qmin
194        self._qmax_unsmeared = self.qmax   
195       
196        # Determine the range needed in unsmeared-Q to cover
197        # the smeared Q range
198        #TODO: use the smearing matrix to determine which
199        # bin range to use
200        if self.smearer.__class__.__name__ == 'SlitSmearer':
201            self._qmin_unsmeared = min(self.data.x)
202            self._qmax_unsmeared = max(self.data.x)
203        elif self.smearer.__class__.__name__ == 'QSmearer':
204            # Take 3 sigmas as the offset between smeared and unsmeared space
205            try:
206                offset = 3.0*max(self.smearer.width)
207                self._qmin_unsmeared = max([min(self.data.x), self.qmin-offset])
208                self._qmax_unsmeared = min([max(self.data.x), self.qmax+offset])
209            except:
210                logging.error("FitData1D.setFitRange: %s" % sys.exc_value)
211        # Identify the bin range for the unsmeared and smeared spaces
212        self.idx = (self.x>=self.qmin) & (self.x <= self.qmax)
213        self.idx_unsmeared = (self.x>=self._qmin_unsmeared) & (self.x <= self._qmax_unsmeared)
214 
215       
216    def getFitRange(self):
217        """
218            @return the range of data.x to fit
219        """
220        return self.qmin, self.qmax
221       
222    def residuals(self, fn):
223        """
224            Compute residuals.
225           
226            If self.smearer has been set, use if to smear
227            the data before computing chi squared.
228           
229            @param fn: function that return model value
230            @return residuals
231        """
232        # Compute theory data f(x)
233        fx= numpy.zeros(len(self.x))
234        _first_bin = None
235        _last_bin  = None
236       
237        fx[self.idx_unsmeared] = fn(self.x[self.idx_unsmeared])
238       
239       
240        for i_x in range(len(self.x)):
241            if self.idx_unsmeared[i_x]==True:
242                # Identify first and last bin
243                #TODO: refactor this to pass q-values to the smearer
244                # and let it figure out which bin range to use
245                if _first_bin is None:
246                    _first_bin = i_x
247                else:
248                    _last_bin  = i_x
249               
250        ## Smear theory data
251        if self.smearer is not None:
252            fx = self.smearer(fx, _first_bin, _last_bin)
253       
254        ## Sanity check
255        if numpy.size(self.dy)!= numpy.size(fx):
256            raise RuntimeError, "FitData1D: invalid error array %d <> %d" % (numpy.shape(self.dy),
257                                                                              numpy.size(fx))
258                                                                             
259        return (self.y[self.idx]-fx[self.idx])/self.dy[self.idx]
260     
261 
262       
263    def residuals_deriv(self, model, pars=[]):
264        """
265            @return residuals derivatives .
266            @note: in this case just return empty array
267        """
268        return []
269   
270   
271class FitData2D(object):
272    """ Wrapper class  for SANS data """
273    def __init__(self,sans_data2d):
274        """
275            Data can be initital with a data (sans plottable)
276            or with vectors.
277        """
278        self.data=sans_data2d
279        self.image = sans_data2d.data
280        self.err_image = sans_data2d.err_data
281        self.x_bins_array= numpy.reshape(sans_data2d.x_bins,
282                                         [1,len(sans_data2d.x_bins)])
283        self.y_bins_array = numpy.reshape(sans_data2d.y_bins,
284                                          [len(sans_data2d.y_bins),1])
285       
286        x = max(self.data.xmin, self.data.xmax)
287        y = max(self.data.ymin, self.data.ymax)
288       
289        ## fitting range
290        self.qmin = 1e-16
291        self.qmax = math.sqrt(x*x +y*y)
292        ## new error image for fitting purpose
293        if self.err_image== None or self.err_image ==[]:
294            self.res_err_image= numpy.zeros(len(self.y_bins),len(self.x_bins))
295        else:
296            self.res_err_image = copy.deepcopy(self.err_image)
297        self.res_err_image[self.err_image==0]=1
298       
299        self.radius= numpy.sqrt(self.x_bins_array**2 + self.y_bins_array**2)
300        self.index_model = (self.qmin <= self.radius)&(self.radius<= self.qmax)
301       
302       
303    def setFitRange(self,qmin=None,qmax=None):
304        """ to set the fit range"""
305        if qmin==0.0:
306            self.qmin = 1e-16
307        elif qmin!=None:                       
308            self.qmin = qmin           
309        if qmax!=None:
310            self.qmax= qmax
311     
312       
313    def getFitRange(self):
314        """
315            @return the range of data.x to fit
316        """
317        return self.qmin, self.qmax
318     
319    def residuals(self, fn): 
320       
321        res=self.index_model*(self.image - fn([self.x_bins_array,
322                             self.y_bins_array]))/self.res_err_image
323        return res.ravel() 
324       
325 
326    def residuals_deriv(self, model, pars=[]):
327        """
328            @return residuals derivatives .
329            @note: in this case just return empty array
330        """
331        return []
332   
333class FitAbort(Exception):
334    """
335        Exception raise to stop the fit
336    """
337    print"Creating fit abort Exception"
338
339
340class SansAssembly:
341    """
342         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
343    """
344    def __init__(self,paramlist,Model=None , Data=None, curr_thread= None):
345        """
346            @param Model: the model wrapper fro sans -model
347            @param Data: the data wrapper for sans data
348        """
349        self.model = Model
350        self.data  = Data
351        self.paramlist=paramlist
352        self.curr_thread= curr_thread
353        self.res=[]
354        self.func_name="Functor"
355    def chisq(self, params):
356        """
357            Calculates chi^2
358            @param params: list of parameter values
359            @return: chi^2
360        """
361        sum = 0
362        for item in self.res:
363            sum += item*item
364        if len(self.res)==0:
365            return None
366        return sum/ len(self.res)
367   
368    def __call__(self,params):
369        """
370            Compute residuals
371            @param params: value of parameters to fit
372        """
373        #import thread
374        self.model.setParams(self.paramlist,params)
375        self.res= self.data.residuals(self.model.eval)
376        #if self.curr_thread != None :
377        #    try:
378        #        self.curr_thread.isquit()
379        #    except:
380        #        raise FitAbort,"stop leastsqr optimizer"   
381        return self.res
382   
383class FitEngine:
384    def __init__(self):
385        """
386            Base class for scipy and park fit engine
387        """
388        #List of parameter names to fit
389        self.paramList=[]
390        #Dictionnary of fitArrange element (fit problems)
391        self.fitArrangeDict={}
392       
393    def _concatenateData(self, listdata=[]):
394        """ 
395            _concatenateData method concatenates each fields of all data contains ins listdata.
396            @param listdata: list of data
397            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
398             of data concatenanted
399            @raise: if listdata is empty  will return None
400            @raise: if data in listdata don't contain dy field ,will create an error
401            during fitting
402        """
403        #TODO: we have to refactor the way we handle data.
404        # We should move away from plottables and move towards the Data1D objects
405        # defined in DataLoader. Data1D allows data manipulations, which should be
406        # used to concatenate.
407        # In the meantime we should switch off the concatenation.
408        #if len(listdata)>1:
409        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
410        #return listdata[0]
411       
412        if listdata==[]:
413            raise ValueError, " data list missing"
414        else:
415            xtemp=[]
416            ytemp=[]
417            dytemp=[]
418            self.mini=None
419            self.maxi=None
420               
421            for item in listdata:
422                data=item.data
423                mini,maxi=data.getFitRange()
424                if self.mini==None and self.maxi==None:
425                    self.mini=mini
426                    self.maxi=maxi
427                else:
428                    if mini < self.mini:
429                        self.mini=mini
430                    if self.maxi < maxi:
431                        self.maxi=maxi
432                       
433                   
434                for i in range(len(data.x)):
435                    xtemp.append(data.x[i])
436                    ytemp.append(data.y[i])
437                    if data.dy is not None and len(data.dy)==len(data.y):   
438                        dytemp.append(data.dy[i])
439                    else:
440                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
441            data= Data(x=xtemp,y=ytemp,dy=dytemp)
442            data.setFitRange(self.mini, self.maxi)
443            return data
444       
445       
446    def set_model(self,model,Uid,pars=[]):
447        """
448            set a model on a given uid in the fit engine.
449            @param model: the model to fit
450            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
451            @param pars: the list of parameters to fit
452            @note : pars must contains only name of existing model's paramaters
453        """
454        if len(pars) >0:
455            if model==None:
456                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
457            else:
458                temp=[]
459                for item in pars:
460                    if item in model.model.getParamList():
461                        temp.append(item)
462                        self.paramList.append(item)
463                    else:
464                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
465                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
466                        return
467            #A fitArrange is already created but contains dList only at Uid
468            if self.fitArrangeDict.has_key(Uid):
469                self.fitArrangeDict[Uid].set_model(model)
470                self.fitArrangeDict[Uid].pars= pars
471            else:
472            #no fitArrange object has been create with this Uid
473                fitproblem = FitArrange()
474                fitproblem.set_model(model)
475                fitproblem.pars= pars
476                self.fitArrangeDict[Uid] = fitproblem
477               
478        else:
479            raise ValueError, "park_integration:missing parameters"
480   
481    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
482        """ Receives plottable, creates a list of data to fit,set data
483            in a FitArrange object and adds that object in a dictionary
484            with key Uid.
485            @param data: data added
486            @param Uid: unique key corresponding to a fitArrange object with data
487        """
488        if data.__class__.__name__=='Data2D':
489            fitdata=FitData2D(data)
490        else:
491            fitdata=FitData1D(data, smearer)
492       
493        fitdata.setFitRange(qmin=qmin,qmax=qmax)
494        #A fitArrange is already created but contains model only at Uid
495        if self.fitArrangeDict.has_key(Uid):
496            self.fitArrangeDict[Uid].add_data(fitdata)
497        else:
498        #no fitArrange object has been create with this Uid
499            fitproblem= FitArrange()
500            fitproblem.add_data(fitdata)
501            self.fitArrangeDict[Uid]=fitproblem   
502   
503    def get_model(self,Uid):
504        """
505            @param Uid: Uid is key in the dictionary containing the model to return
506            @return  a model at this uid or None if no FitArrange element was created
507            with this Uid
508        """
509        if self.fitArrangeDict.has_key(Uid):
510            return self.fitArrangeDict[Uid].get_model()
511        else:
512            return None
513   
514    def remove_Fit_Problem(self,Uid):
515        """remove   fitarrange in Uid"""
516        if self.fitArrangeDict.has_key(Uid):
517            del self.fitArrangeDict[Uid]
518           
519    def select_problem_for_fit(self,Uid,value):
520        """
521            select a couple of model and data at the Uid position in dictionary
522            and set in self.selected value to value
523            @param value: the value to allow fitting. can only have the value one or zero
524        """
525        if self.fitArrangeDict.has_key(Uid):
526             self.fitArrangeDict[Uid].set_to_fit( value)
527             
528             
529    def get_problem_to_fit(self,Uid):
530        """
531            return the self.selected value of the fit problem of Uid
532           @param Uid: the Uid of the problem
533        """
534        if self.fitArrangeDict.has_key(Uid):
535             self.fitArrangeDict[Uid].get_to_fit()
536   
537class FitArrange:
538    def __init__(self):
539        """
540            Class FitArrange contains a set of data for a given model
541            to perform the Fit.FitArrange must contain exactly one model
542            and at least one data for the fit to be performed.
543            model: the model selected by the user
544            Ldata: a list of data what the user wants to fit
545           
546        """
547        self.model = None
548        self.dList =[]
549        self.pars=[]
550        #self.selected  is zero when this fit problem is not schedule to fit
551        #self.selected is 1 when schedule to fit
552        self.selected = 0
553       
554    def set_model(self,model):
555        """
556            set_model save a copy of the model
557            @param model: the model being set
558        """
559        self.model = model
560       
561    def add_data(self,data):
562        """
563            add_data fill a self.dList with data to fit
564            @param data: Data to add in the list 
565        """
566        if not data in self.dList:
567            self.dList.append(data)
568           
569    def get_model(self):
570        """ @return: saved model """
571        return self.model   
572     
573    def get_data(self):
574        """ @return:  list of data dList"""
575        #return self.dList
576        return self.dList[0] 
577     
578    def remove_data(self,data):
579        """
580            Remove one element from the list
581            @param data: Data to remove from dList
582        """
583        if data in self.dList:
584            self.dList.remove(data)
585    def set_to_fit (self, value=0):
586        """
587           set self.selected to 0 or 1  for other values raise an exception
588           @param value: integer between 0 or 1
589        """
590        self.selected= value
591       
592    def get_to_fit(self):
593        """
594            @return self.selected value
595        """
596        return self.selected
Note: See TracBrowser for help on using the repository browser.