source: sasview/park_integration/AbstractFitEngine.py @ 7c8b6a5

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 7c8b6a5 was 189be4e, checked in by Jae Cho <jhjcho@…>, 14 years ago

fixed fitting error on theory curve

  • Property mode set to 100644
File size: 24.1 KB
Line 
1
2
3import logging, sys
4import park,numpy,math, copy
5from DataLoader.data_info import Data1D
6from DataLoader.data_info import Data2D
7
8class SansParameter(park.Parameter):
9    """
10    SANS model parameters for use in the PARK fitting service.
11    The parameter attribute value is redirected to the underlying
12    parameter value in the SANS model.
13    """
14    def __init__(self, name, model):
15        """
16        :param name: the name of the model parameter
17        :param model: the sans model to wrap as a park model
18       
19        """
20        self._model, self._name = model,name
21        #set the value for the parameter of the given name
22        self.set(model.getParam(name))
23         
24    def _getvalue(self):
25        """
26        override the _getvalue of park parameter
27       
28        :return value the parameter associates with self.name
29       
30        """
31        return self._model.getParam(self.name)
32   
33    def _setvalue(self,value):
34        """
35        override the _setvalue pf park parameter
36       
37        :param value: the value to set on a given parameter
38       
39        """
40        self._model.setParam(self.name, value)
41       
42    value = property(_getvalue,_setvalue)
43   
44    def _getrange(self):
45        """
46        Override _getrange of park parameter
47        return the range of parameter
48        """
49        #if not  self.name in self._model.getDispParamList():
50        lo,hi = self._model.details[self.name][1:3]
51        if lo is None: lo = -numpy.inf
52        if hi is None: hi = numpy.inf
53        #else:
54            #lo,hi = self._model.details[self.name][1:]
55            #if lo is None: lo = -numpy.inf
56            #if hi is None: hi = numpy.inf
57        if lo >= hi:
58            raise ValueError,"wrong fit range for parameters"
59       
60        return lo,hi
61   
62    def _setrange(self,r):
63        """
64        override _setrange of park parameter
65       
66        :param r: the value of the range to set
67       
68        """
69        self._model.details[self.name][1:3] = r
70    range = property(_getrange,_setrange)
71   
72class Model(park.Model):
73    """
74    PARK wrapper for SANS models.
75    """
76    def __init__(self, sans_model, **kw):
77        """
78        :param sans_model: the sans model to wrap using park interface
79       
80        """
81        park.Model.__init__(self, **kw)
82        self.model = sans_model
83        self.name = sans_model.name
84        #list of parameters names
85        self.sansp = sans_model.getParamList()
86        #list of park parameter
87        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
88        #list of parameterset
89        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
90        self.pars=[]
91 
92    def getParams(self,fitparams):
93        """
94        return a list of value of paramter to fit
95       
96        :param fitparams: list of paramaters name to fit
97       
98        """
99        list=[]
100        self.pars=[]
101        self.pars=fitparams
102        for item in fitparams:
103            for element in self.parkp:
104                 if element.name ==str(item):
105                     list.append(element.value)
106        return list
107   
108    def setParams(self,paramlist, params):
109        """
110        Set value for parameters to fit
111       
112        :param params: list of value for parameters to fit
113       
114        """
115        try:
116            for i in range(len(self.parkp)):
117                for j in range(len(paramlist)):
118                    if self.parkp[i].name==paramlist[j]:
119                        self.parkp[i].value = params[j]
120                        self.model.setParam(self.parkp[i].name,params[j])
121        except:
122            raise
123 
124    def eval(self,x):
125        """
126        override eval method of park model.
127       
128        :param x: the x value used to compute a function
129       
130        """
131        try:
132            return self.model.evalDistribution(x)
133        except:
134            raise
135
136   
137class FitData1D(Data1D):
138    """
139    Wrapper class  for SANS data
140    FitData1D inherits from DataLoader.data_info.Data1D. Implements
141    a way to get residuals from data.
142    """
143    def __init__(self,x, y,dx= None, dy=None, smearer=None):
144        Data1D.__init__(self, x=numpy.array(x), y=numpy.array(y), dx=dx, dy=dy)
145        """
146        :param smearer: is an object of class QSmearer or SlitSmearer
147           that will smear the theory data (slit smearing or resolution
148           smearing) when set.
149       
150        The proper way to set the smearing object would be to
151        do the following: ::
152       
153            from DataLoader.qsmearing import smear_selection
154            smearer = smear_selection(some_data)
155            fitdata1d = FitData1D( x= [1,3,..,],
156                                    y= [3,4,..,8],
157                                    dx=None,
158                                    dy=[1,2...], smearer= smearer)
159       
160        :Note: that some_data _HAS_ to be of class DataLoader.data_info.Data1D
161            Setting it back to None will turn smearing off.
162           
163        """
164        self.smearer = smearer
165       
166        # Check error bar; if no error bar found, set it constant(=1)
167        # TODO: Should provide an option for users to set it like percent, constant, or dy data
168        if dy ==None or dy==[] or dy.all()==0:
169            self.dy= numpy.ones(len(y)) 
170        else:
171            self.dy= numpy.asarray(dy).copy()
172
173        ## Min Q-value
174        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
175        if min (self.x) ==0.0 and self.x[0]==0 and not numpy.isfinite(self.y[0]):
176            self.qmin = min(self.x[self.x!=0])
177        else:                             
178            self.qmin= min (self.x)
179        ## Max Q-value
180        self.qmax = max (self.x)
181       
182        # Range used for input to smearing
183        self._qmin_unsmeared = self.qmin
184        self._qmax_unsmeared = self.qmax
185        # Identify the bin range for the unsmeared and smeared spaces
186        self.idx = (self.x>=self.qmin) & (self.x <= self.qmax)
187        self.idx_unsmeared = (self.x>=self._qmin_unsmeared) & (self.x <= self._qmax_unsmeared)
188 
189       
190       
191    def setFitRange(self,qmin=None,qmax=None):
192        """ to set the fit range"""
193        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
194        # ToDo: Find better way to do it.
195        if qmin==0.0 and not numpy.isfinite(self.y[qmin]):
196            self.qmin = min(self.x[self.x!=0])
197        elif qmin!=None:                       
198            self.qmin = qmin           
199
200        if qmax !=None:
201            self.qmax = qmax
202           
203        # Determine the range needed in unsmeared-Q to cover
204        # the smeared Q range
205        self._qmin_unsmeared = self.qmin
206        self._qmax_unsmeared = self.qmax   
207       
208        self._first_unsmeared_bin = 0
209        self._last_unsmeared_bin  = len(self.x)-1
210       
211        if self.smearer!=None:
212            self._first_unsmeared_bin, self._last_unsmeared_bin = self.smearer.get_bin_range(self.qmin, self.qmax)
213            self._qmin_unsmeared = self.x[self._first_unsmeared_bin]
214            self._qmax_unsmeared = self.x[self._last_unsmeared_bin]
215           
216        # Identify the bin range for the unsmeared and smeared spaces
217        self.idx = (self.x>=self.qmin) & (self.x <= self.qmax)
218        self.idx = self.idx & (self.dy!=0)   ## zero error can not participate for fitting
219        self.idx_unsmeared = (self.x>=self._qmin_unsmeared) & (self.x <= self._qmax_unsmeared)
220 
221       
222    def getFitRange(self):
223        """
224        return the range of data.x to fit
225        """
226        return self.qmin, self.qmax
227       
228    def residuals(self, fn):
229        """
230        Compute residuals.
231       
232        If self.smearer has been set, use if to smear
233        the data before computing chi squared.
234       
235        :param fn: function that return model value
236       
237        :return: residuals
238       
239        """
240        # Compute theory data f(x)
241        fx= numpy.zeros(len(self.x))
242        fx[self.idx_unsmeared] = fn(self.x[self.idx_unsmeared])
243       
244        ## Smear theory data
245        if self.smearer is not None:
246            fx = self.smearer(fx, self._first_unsmeared_bin, self._last_unsmeared_bin)
247
248        ## Sanity check
249        if numpy.size(self.dy)!= numpy.size(fx):
250            raise RuntimeError, "FitData1D: invalid error array %d <> %d" % (numpy.shape(self.dy),
251                                                                              numpy.size(fx))
252                                                                             
253        return (self.y[self.idx]-fx[self.idx])/self.dy[self.idx]
254     
255    def residuals_deriv(self, model, pars=[]):
256        """
257        :return: residuals derivatives .
258       
259        :note: in this case just return empty array
260       
261        """
262        return []
263   
264   
265class FitData2D(Data2D):
266    """ Wrapper class  for SANS data """
267    def __init__(self,sans_data2d ,data=None, err_data=None):
268        Data2D.__init__(self, data= data, err_data= err_data)
269        """
270        Data can be initital with a data (sans plottable)
271        or with vectors.
272        """
273        self.res_err_image=[]
274        self.index_model=[]
275        self.qmin= None
276        self.qmax= None
277        self.smearer = None
278        self.set_data(sans_data2d )
279
280       
281    def set_data(self, sans_data2d, qmin=None, qmax=None ):
282        """
283        Determine the correct qx_data and qy_data within range to fit
284        """
285        self.data     = sans_data2d.data
286        self.err_data = sans_data2d.err_data
287        self.qx_data = sans_data2d.qx_data
288        self.qy_data = sans_data2d.qy_data
289        self.mask       = sans_data2d.mask
290
291        x_max = max(math.fabs(sans_data2d.xmin), math.fabs(sans_data2d.xmax))
292        y_max = max(math.fabs(sans_data2d.ymin), math.fabs(sans_data2d.ymax))
293       
294        ## fitting range
295        if qmin == None:
296            self.qmin = 1e-16
297        if qmax == None:
298            self.qmax = math.sqrt(x_max*x_max +y_max*y_max)
299        ## new error image for fitting purpose
300        if self.err_data== None or self.err_data ==[]:
301            self.res_err_data= numpy.ones(len(self.data))
302        else:
303            self.res_err_data = copy.deepcopy(self.err_data)
304        #self.res_err_data[self.res_err_data==0]=1
305       
306        self.radius= numpy.sqrt(self.qx_data**2 + self.qy_data**2)
307       
308        # Note: mask = True: for MASK while mask = False for NOT to mask
309        self.index_model = ((self.qmin <= self.radius)&(self.radius<= self.qmax))
310        self.index_model = (self.index_model) & (self.mask)
311        self.index_model = (self.index_model) & (numpy.isfinite(self.data))
312       
313    def set_smearer(self,smearer): 
314        """
315        Set smearer
316        """
317        if smearer == None:
318            return
319        self.smearer = smearer
320        self.smearer.set_index(self.index_model)
321        self.smearer.get_data()
322
323    def setFitRange(self,qmin=None,qmax=None):
324        """ to set the fit range"""
325        if qmin==0.0:
326            self.qmin = 1e-16
327        elif qmin!=None:                       
328            self.qmin = qmin           
329        if qmax!=None:
330            self.qmax= qmax       
331        self.radius= numpy.sqrt(self.qx_data**2 + self.qy_data**2)
332        self.index_model = ((self.qmin <= self.radius)&(self.radius<= self.qmax))
333        self.index_model = (self.index_model) &(self.mask)
334        self.index_model = (self.index_model) & (numpy.isfinite(self.data))
335        self.index_model = (self.index_model) & (self.res_err_data!=0)
336       
337    def getFitRange(self):
338        """
339        return the range of data.x to fit
340        """
341        return self.qmin, self.qmax
342     
343    def residuals(self, fn): 
344        """
345        return the residuals
346        """ 
347        if self.smearer != None:
348            fn.set_index(self.index_model)
349            # Get necessary data from self.data and set the data for smearing
350            fn.get_data()
351
352            gn = fn.get_value() 
353        else:
354            gn = fn([self.qx_data[self.index_model],self.qy_data[self.index_model]])
355        # use only the data point within ROI range
356        res=(self.data[self.index_model] - gn)/self.res_err_data[self.index_model]
357        return res
358       
359    def residuals_deriv(self, model, pars=[]):
360        """
361        :return: residuals derivatives .
362       
363        :note: in this case just return empty array
364       
365        """
366        return []
367   
368class FitAbort(Exception):
369    """
370    Exception raise to stop the fit
371    """
372    #print"Creating fit abort Exception"
373
374
375class SansAssembly:
376    """
377    Sans Assembly class a class wrapper to be call in optimizer.leastsq method
378    """
379    def __init__(self, paramlist, model=None , data=None, fitresult=None,
380                 handler=None, curr_thread=None):
381        """
382        :param Model: the model wrapper fro sans -model
383        :param Data: the data wrapper for sans data
384       
385        """
386        self.model = model
387        self.data  = data
388        self.paramlist = paramlist
389        self.curr_thread = curr_thread
390        self.handler = handler
391        self.fitresult = fitresult
392        self.res = []
393        self.func_name = "Functor"
394       
395    def chisq(self, params):
396        """
397        Calculates chi^2
398       
399        :param params: list of parameter values
400       
401        :return: chi^2
402       
403        """
404        sum = 0
405        for item in self.res:
406            sum += item*item
407        if len(self.res)==0:
408            return None
409        return sum/ len(self.res)
410   
411    def __call__(self,params):
412        """
413        Compute residuals
414       
415        :param params: value of parameters to fit
416       
417        """
418        #import thread
419        self.model.setParams(self.paramlist,params)
420        self.res= self.data.residuals(self.model.eval)
421        if self.fitresult is not None and  self.handler is not None:
422            self.fitresult.set_model(model=self.model)
423            fitness = self.chisq(params=params)
424            self.fitresult.set_fitness(fitness=fitness)
425            self.handler.set_result(result=self.fitresult)
426            self.handler.update_fit()
427       
428        #if self.curr_thread != None :
429        #    try:
430        #        self.curr_thread.isquit()
431        #    except:
432        #        raise FitAbort,"stop leastsqr optimizer"   
433        return self.res
434   
435class FitEngine:
436    def __init__(self):
437        """
438        Base class for scipy and park fit engine
439        """
440        #List of parameter names to fit
441        self.paramList=[]
442        #Dictionnary of fitArrange element (fit problems)
443        self.fitArrangeDict={}
444       
445    def _concatenateData(self, listdata=[]):
446        """ 
447        _concatenateData method concatenates each fields of all data
448        contains ins listdata.
449       
450        :param listdata: list of data
451       
452        :return Data: Data is wrapper class for sans plottable. it is created with all parameters
453            of data concatenanted
454           
455        :raise: if listdata is empty  will return None
456        :raise: if data in listdata don't contain dy field ,will create an error
457            during fitting
458           
459        """
460        #TODO: we have to refactor the way we handle data.
461        # We should move away from plottables and move towards the Data1D objects
462        # defined in DataLoader. Data1D allows data manipulations, which should be
463        # used to concatenate.
464        # In the meantime we should switch off the concatenation.
465        #if len(listdata)>1:
466        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
467        #return listdata[0]
468       
469        if listdata==[]:
470            raise ValueError, " data list missing"
471        else:
472            xtemp=[]
473            ytemp=[]
474            dytemp=[]
475            self.mini=None
476            self.maxi=None
477               
478            for item in listdata:
479                data=item.data
480                mini,maxi=data.getFitRange()
481                if self.mini==None and self.maxi==None:
482                    self.mini=mini
483                    self.maxi=maxi
484                else:
485                    if mini < self.mini:
486                        self.mini=mini
487                    if self.maxi < maxi:
488                        self.maxi=maxi
489                       
490                   
491                for i in range(len(data.x)):
492                    xtemp.append(data.x[i])
493                    ytemp.append(data.y[i])
494                    if data.dy is not None and len(data.dy)==len(data.y):   
495                        dytemp.append(data.dy[i])
496                    else:
497                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
498            data= Data(x=xtemp,y=ytemp,dy=dytemp)
499            data.setFitRange(self.mini, self.maxi)
500            return data
501       
502       
503    def set_model(self, model, Uid, pars=[], constraints=[]):
504        """
505        set a model on a given uid in the fit engine.
506       
507        :param model: sans.models type
508        :param Uid: is the key of the fitArrange dictionary where model is
509                saved as a value
510        :param pars: the list of parameters to fit
511        :param constraints: list of
512            tuple (name of parameter, value of parameters)
513            the value of parameter must be a string to constraint 2 different
514            parameters.
515            Example: 
516            we want to fit 2 model M1 and M2 both have parameters A and B.
517            constraints can be:
518             constraints = [(M1.A, M2.B+2), (M1.B= M2.A *5),...,]
519           
520             
521        :note: pars must contains only name of existing model's parameters
522       
523        """
524        if model == None:
525            raise ValueError, "AbstractFitEngine: Need to set model to fit"
526       
527        new_model= model
528        if not issubclass(model.__class__, Model):
529            new_model= Model(model)
530       
531        if len(constraints)>0:
532            for constraint in constraints:
533                name, value = constraint
534                try:
535                    new_model.parameterset[ str(name)].set( str(value) )
536                except:
537                    msg= "Fit Engine: Error occurs when setting the constraint"
538                    msg += " %s for parameter %s "%(value, name)
539                    raise ValueError, msg
540               
541        if len(pars) >0:
542            temp=[]
543            for item in pars:
544                if item in new_model.model.getParamList():
545                    temp.append(item)
546                    self.paramList.append(item)
547                else:
548                   
549                    msg = "wrong parameter %s used"%str(item)
550                    msg += "to set model %s. Choose"%str(new_model.model.name)
551                    msg += "parameter name within %s"%str(new_model.model.getParamList())
552                    raise ValueError,msg
553             
554            #A fitArrange is already created but contains dList only at Uid
555            if self.fitArrangeDict.has_key(Uid):
556                self.fitArrangeDict[Uid].set_model(new_model)
557                self.fitArrangeDict[Uid].pars= pars
558            else:
559            #no fitArrange object has been create with this Uid
560                fitproblem = FitArrange()
561                fitproblem.set_model(new_model)
562                fitproblem.pars= pars
563                self.fitArrangeDict[Uid] = fitproblem
564               
565        else:
566            raise ValueError, "park_integration:missing parameters"
567   
568    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
569        """
570        Receives plottable, creates a list of data to fit,set data
571        in a FitArrange object and adds that object in a dictionary
572        with key Uid.
573       
574        :param data: data added
575        :param Uid: unique key corresponding to a fitArrange object with data
576       
577        """
578        if data.__class__.__name__=='Data2D':
579            fitdata=FitData2D(sans_data2d=data, data=data.data, err_data= data.err_data)
580        else:
581            fitdata=FitData1D(x=data.x, y=data.y , dx= data.dx,dy=data.dy,smearer=smearer)
582       
583        fitdata.setFitRange(qmin=qmin,qmax=qmax)
584        #A fitArrange is already created but contains model only at Uid
585        if self.fitArrangeDict.has_key(Uid):
586            self.fitArrangeDict[Uid].add_data(fitdata)
587        else:
588        #no fitArrange object has been create with this Uid
589            fitproblem= FitArrange()
590            fitproblem.add_data(fitdata)
591            self.fitArrangeDict[Uid]=fitproblem   
592   
593    def get_model(self,Uid):
594        """
595       
596        :param Uid: Uid is key in the dictionary containing the model to return
597       
598        :return:  a model at this uid or None if no FitArrange element was created
599            with this Uid
600           
601        """
602        if self.fitArrangeDict.has_key(Uid):
603            return self.fitArrangeDict[Uid].get_model()
604        else:
605            return None
606   
607    def remove_Fit_Problem(self,Uid):
608        """remove   fitarrange in Uid"""
609        if self.fitArrangeDict.has_key(Uid):
610            del self.fitArrangeDict[Uid]
611           
612    def select_problem_for_fit(self,Uid,value):
613        """
614        select a couple of model and data at the Uid position in dictionary
615        and set in self.selected value to value
616       
617        :param value: the value to allow fitting.
618                can only have the value one or zero
619               
620        """
621        if self.fitArrangeDict.has_key(Uid):
622             self.fitArrangeDict[Uid].set_to_fit( value)
623             
624    def get_problem_to_fit(self,Uid):
625        """
626        return the self.selected value of the fit problem of Uid
627       
628        :param Uid: the Uid of the problem
629       
630        """
631        if self.fitArrangeDict.has_key(Uid):
632             self.fitArrangeDict[Uid].get_to_fit()
633   
634class FitArrange:
635    def __init__(self):
636        """
637        Class FitArrange contains a set of data for a given model
638        to perform the Fit.FitArrange must contain exactly one model
639        and at least one data for the fit to be performed.
640       
641        model: the model selected by the user
642        Ldata: a list of data what the user wants to fit
643           
644        """
645        self.model = None
646        self.dList =[]
647        self.pars=[]
648        #self.selected  is zero when this fit problem is not schedule to fit
649        #self.selected is 1 when schedule to fit
650        self.selected = 0
651       
652    def set_model(self,model):
653        """
654        set_model save a copy of the model
655       
656        :param model: the model being set
657       
658        """
659        self.model = model
660       
661    def add_data(self,data):
662        """
663        add_data fill a self.dList with data to fit
664       
665        :param data: Data to add in the list 
666       
667        """
668        if not data in self.dList:
669            self.dList.append(data)
670           
671    def get_model(self):
672        """
673       
674        :return: saved model
675       
676        """
677        return self.model   
678     
679    def get_data(self):
680        """
681       
682        :return: list of data dList
683       
684        """
685        #return self.dList
686        return self.dList[0] 
687     
688    def remove_data(self,data):
689        """
690        Remove one element from the list
691       
692        :param data: Data to remove from dList
693       
694        """
695        if data in self.dList:
696            self.dList.remove(data)
697           
698    def set_to_fit (self, value=0):
699        """
700        set self.selected to 0 or 1  for other values raise an exception
701       
702        :param value: integer between 0 or 1
703       
704        """
705        self.selected= value
706       
707    def get_to_fit(self):
708        """
709        return self.selected value
710        """
711        return self.selected
Note: See TracBrowser for help on using the repository browser.