source: sasview/park_integration/AbstractFitEngine.py @ c6d3301

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since c6d3301 was c6d3301, checked in by Jae Cho <jhjcho@…>, 14 years ago

set dy = 1 only when all dy ==0

  • Property mode set to 100644
File size: 24.1 KB
Line 
1
2
3import logging, sys
4import park,numpy,math, copy
5from DataLoader.data_info import Data1D
6from DataLoader.data_info import Data2D
7
8class SansParameter(park.Parameter):
9    """
10    SANS model parameters for use in the PARK fitting service.
11    The parameter attribute value is redirected to the underlying
12    parameter value in the SANS model.
13    """
14    def __init__(self, name, model):
15        """
16        :param name: the name of the model parameter
17        :param model: the sans model to wrap as a park model
18       
19        """
20        self._model, self._name = model,name
21        #set the value for the parameter of the given name
22        self.set(model.getParam(name))
23         
24    def _getvalue(self):
25        """
26        override the _getvalue of park parameter
27       
28        :return value the parameter associates with self.name
29       
30        """
31        return self._model.getParam(self.name)
32   
33    def _setvalue(self,value):
34        """
35        override the _setvalue pf park parameter
36       
37        :param value: the value to set on a given parameter
38       
39        """
40        self._model.setParam(self.name, value)
41       
42    value = property(_getvalue,_setvalue)
43   
44    def _getrange(self):
45        """
46        Override _getrange of park parameter
47        return the range of parameter
48        """
49        #if not  self.name in self._model.getDispParamList():
50        lo,hi = self._model.details[self.name][1:3]
51        if lo is None: lo = -numpy.inf
52        if hi is None: hi = numpy.inf
53        #else:
54            #lo,hi = self._model.details[self.name][1:]
55            #if lo is None: lo = -numpy.inf
56            #if hi is None: hi = numpy.inf
57        if lo >= hi:
58            raise ValueError,"wrong fit range for parameters"
59       
60        return lo,hi
61   
62    def _setrange(self,r):
63        """
64        override _setrange of park parameter
65       
66        :param r: the value of the range to set
67       
68        """
69        self._model.details[self.name][1:3] = r
70    range = property(_getrange,_setrange)
71   
72class Model(park.Model):
73    """
74    PARK wrapper for SANS models.
75    """
76    def __init__(self, sans_model, **kw):
77        """
78        :param sans_model: the sans model to wrap using park interface
79       
80        """
81        park.Model.__init__(self, **kw)
82        self.model = sans_model
83        self.name = sans_model.name
84        #list of parameters names
85        self.sansp = sans_model.getParamList()
86        #list of park parameter
87        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
88        #list of parameterset
89        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
90        self.pars=[]
91 
92    def getParams(self,fitparams):
93        """
94        return a list of value of paramter to fit
95       
96        :param fitparams: list of paramaters name to fit
97       
98        """
99        list=[]
100        self.pars=[]
101        self.pars=fitparams
102        for item in fitparams:
103            for element in self.parkp:
104                 if element.name ==str(item):
105                     list.append(element.value)
106        return list
107   
108    def setParams(self,paramlist, params):
109        """
110        Set value for parameters to fit
111       
112        :param params: list of value for parameters to fit
113       
114        """
115        try:
116            for i in range(len(self.parkp)):
117                for j in range(len(paramlist)):
118                    if self.parkp[i].name==paramlist[j]:
119                        self.parkp[i].value = params[j]
120                        self.model.setParam(self.parkp[i].name,params[j])
121        except:
122            raise
123 
124    def eval(self,x):
125        """
126        override eval method of park model.
127       
128        :param x: the x value used to compute a function
129       
130        """
131        try:
132            return self.model.evalDistribution(x)
133        except:
134            raise
135
136   
137class FitData1D(Data1D):
138    """
139    Wrapper class  for SANS data
140    FitData1D inherits from DataLoader.data_info.Data1D. Implements
141    a way to get residuals from data.
142    """
143    def __init__(self,x, y,dx= None, dy=None, smearer=None):
144        Data1D.__init__(self, x=numpy.array(x), y=numpy.array(y), dx=dx, dy=dy)
145        """
146        :param smearer: is an object of class QSmearer or SlitSmearer
147           that will smear the theory data (slit smearing or resolution
148           smearing) when set.
149       
150        The proper way to set the smearing object would be to
151        do the following: ::
152       
153            from DataLoader.qsmearing import smear_selection
154            smearer = smear_selection(some_data)
155            fitdata1d = FitData1D( x= [1,3,..,],
156                                    y= [3,4,..,8],
157                                    dx=None,
158                                    dy=[1,2...], smearer= smearer)
159       
160        :Note: that some_data _HAS_ to be of class DataLoader.data_info.Data1D
161            Setting it back to None will turn smearing off.
162           
163        """
164       
165        self.smearer = smearer
166        if dy ==None or dy==[]:
167            self.dy= numpy.zeros(len(self.y)) 
168        else:
169            self.dy= numpy.asarray(dy)
170     
171        # For fitting purposes, replace zero errors by 1 if all of them are zero
172        #TODO: check validity for the rare case where only
173        # all points have zero errors
174        if self.dy.all()==0: self.dy = 1
175       
176        ## Min Q-value
177        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
178        if min (self.x) ==0.0 and self.x[0]==0 and not numpy.isfinite(self.y[0]):
179            self.qmin = min(self.x[self.x!=0])
180        else:                             
181            self.qmin= min (self.x)
182        ## Max Q-value
183        self.qmax = max (self.x)
184       
185        # Range used for input to smearing
186        self._qmin_unsmeared = self.qmin
187        self._qmax_unsmeared = self.qmax
188        # Identify the bin range for the unsmeared and smeared spaces
189        self.idx = (self.x>=self.qmin) & (self.x <= self.qmax)
190        self.idx_unsmeared = (self.x>=self._qmin_unsmeared) & (self.x <= self._qmax_unsmeared)
191 
192       
193       
194    def setFitRange(self,qmin=None,qmax=None):
195        """ to set the fit range"""
196        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
197        #ToDo: Fix this.
198        if qmin==0.0 and not numpy.isfinite(self.y[qmin]):
199            self.qmin = min(self.x[self.x!=0])
200        elif qmin!=None:                       
201            self.qmin = qmin           
202
203        if qmax !=None:
204            self.qmax = qmax
205           
206        # Determine the range needed in unsmeared-Q to cover
207        # the smeared Q range
208        self._qmin_unsmeared = self.qmin
209        self._qmax_unsmeared = self.qmax   
210       
211        self._first_unsmeared_bin = 0
212        self._last_unsmeared_bin  = len(self.x)-1
213       
214        if self.smearer!=None:
215            self._first_unsmeared_bin, self._last_unsmeared_bin = self.smearer.get_bin_range(self.qmin, self.qmax)
216            self._qmin_unsmeared = self.x[self._first_unsmeared_bin]
217            self._qmax_unsmeared = self.x[self._last_unsmeared_bin]
218           
219        # Identify the bin range for the unsmeared and smeared spaces
220        self.idx = (self.x>=self.qmin) & (self.x <= self.qmax)
221        self.idx = self.idx & (self.dy!=0)   ## zero error can not participate for fitting
222        self.idx_unsmeared = (self.x>=self._qmin_unsmeared) & (self.x <= self._qmax_unsmeared)
223 
224       
225    def getFitRange(self):
226        """
227        return the range of data.x to fit
228        """
229        return self.qmin, self.qmax
230       
231    def residuals(self, fn):
232        """
233        Compute residuals.
234       
235        If self.smearer has been set, use if to smear
236        the data before computing chi squared.
237       
238        :param fn: function that return model value
239       
240        :return: residuals
241       
242        """
243        # Compute theory data f(x)
244        fx= numpy.zeros(len(self.x))
245        fx[self.idx_unsmeared] = fn(self.x[self.idx_unsmeared])
246       
247        ## Smear theory data
248        if self.smearer is not None:
249            fx = self.smearer(fx, self._first_unsmeared_bin, self._last_unsmeared_bin)
250       
251        ## Sanity check
252        if numpy.size(self.dy)!= numpy.size(fx):
253            raise RuntimeError, "FitData1D: invalid error array %d <> %d" % (numpy.shape(self.dy),
254                                                                              numpy.size(fx))
255                                                                             
256        return (self.y[self.idx]-fx[self.idx])/self.dy[self.idx]
257     
258    def residuals_deriv(self, model, pars=[]):
259        """
260        :return: residuals derivatives .
261       
262        :note: in this case just return empty array
263       
264        """
265        return []
266   
267   
268class FitData2D(Data2D):
269    """ Wrapper class  for SANS data """
270    def __init__(self,sans_data2d ,data=None, err_data=None):
271        Data2D.__init__(self, data= data, err_data= err_data)
272        """
273        Data can be initital with a data (sans plottable)
274        or with vectors.
275        """
276        self.res_err_image=[]
277        self.index_model=[]
278        self.qmin= None
279        self.qmax= None
280        self.smearer = None
281        self.set_data(sans_data2d )
282
283       
284    def set_data(self, sans_data2d, qmin=None, qmax=None ):
285        """
286        Determine the correct qx_data and qy_data within range to fit
287        """
288        self.data     = sans_data2d.data
289        self.err_data = sans_data2d.err_data
290        self.qx_data = sans_data2d.qx_data
291        self.qy_data = sans_data2d.qy_data
292        self.mask       = sans_data2d.mask
293
294        x_max = max(math.fabs(sans_data2d.xmin), math.fabs(sans_data2d.xmax))
295        y_max = max(math.fabs(sans_data2d.ymin), math.fabs(sans_data2d.ymax))
296       
297        ## fitting range
298        if qmin == None:
299            self.qmin = 1e-16
300        if qmax == None:
301            self.qmax = math.sqrt(x_max*x_max +y_max*y_max)
302        ## new error image for fitting purpose
303        if self.err_data== None or self.err_data ==[]:
304            self.res_err_data= numpy.ones(len(self.data))
305        else:
306            self.res_err_data = copy.deepcopy(self.err_data)
307        #self.res_err_data[self.res_err_data==0]=1
308       
309        self.radius= numpy.sqrt(self.qx_data**2 + self.qy_data**2)
310       
311        # Note: mask = True: for MASK while mask = False for NOT to mask
312        self.index_model = ((self.qmin <= self.radius)&(self.radius<= self.qmax))
313        self.index_model = (self.index_model) & (self.mask)
314        self.index_model = (self.index_model) & (numpy.isfinite(self.data))
315       
316    def set_smearer(self,smearer): 
317        """
318        Set smearer
319        """
320        if smearer == None:
321            return
322        self.smearer = smearer
323        self.smearer.set_index(self.index_model)
324        self.smearer.get_data()
325
326    def setFitRange(self,qmin=None,qmax=None):
327        """ to set the fit range"""
328        if qmin==0.0:
329            self.qmin = 1e-16
330        elif qmin!=None:                       
331            self.qmin = qmin           
332        if qmax!=None:
333            self.qmax= qmax       
334        self.radius= numpy.sqrt(self.qx_data**2 + self.qy_data**2)
335        self.index_model = ((self.qmin <= self.radius)&(self.radius<= self.qmax))
336        self.index_model = (self.index_model) &(self.mask)
337        self.index_model = (self.index_model) & (numpy.isfinite(self.data))
338        self.index_model = (self.index_model) & (self.res_err_data!=0)
339       
340    def getFitRange(self):
341        """
342        return the range of data.x to fit
343        """
344        return self.qmin, self.qmax
345     
346    def residuals(self, fn): 
347        """
348        return the residuals
349        """ 
350        if self.smearer != None:
351            fn.set_index(self.index_model)
352            # Get necessary data from self.data and set the data for smearing
353            fn.get_data()
354
355            gn = fn.get_value() 
356        else:
357            gn = fn([self.qx_data[self.index_model],self.qy_data[self.index_model]])
358        # use only the data point within ROI range
359        res=(self.data[self.index_model] - gn)/self.res_err_data[self.index_model]
360        return res
361       
362    def residuals_deriv(self, model, pars=[]):
363        """
364        :return: residuals derivatives .
365       
366        :note: in this case just return empty array
367       
368        """
369        return []
370   
371class FitAbort(Exception):
372    """
373    Exception raise to stop the fit
374    """
375    #print"Creating fit abort Exception"
376
377
378class SansAssembly:
379    """
380    Sans Assembly class a class wrapper to be call in optimizer.leastsq method
381    """
382    def __init__(self, paramlist, model=None , data=None, fitresult=None,
383                 handler=None, curr_thread=None):
384        """
385        :param Model: the model wrapper fro sans -model
386        :param Data: the data wrapper for sans data
387       
388        """
389        self.model = model
390        self.data  = data
391        self.paramlist = paramlist
392        self.curr_thread = curr_thread
393        self.handler = handler
394        self.fitresult = fitresult
395        self.res = []
396        self.func_name = "Functor"
397       
398    def chisq(self, params):
399        """
400        Calculates chi^2
401       
402        :param params: list of parameter values
403       
404        :return: chi^2
405       
406        """
407        sum = 0
408        for item in self.res:
409            sum += item*item
410        if len(self.res)==0:
411            return None
412        return sum/ len(self.res)
413   
414    def __call__(self,params):
415        """
416        Compute residuals
417       
418        :param params: value of parameters to fit
419       
420        """
421        #import thread
422        self.model.setParams(self.paramlist,params)
423        self.res= self.data.residuals(self.model.eval)
424        if self.fitresult is not None and  self.handler is not None:
425            self.fitresult.set_model(model=self.model)
426            fitness = self.chisq(params=params)
427            self.fitresult.set_fitness(fitness=fitness)
428            self.handler.set_result(result=self.fitresult)
429            self.handler.update_fit()
430       
431        #if self.curr_thread != None :
432        #    try:
433        #        self.curr_thread.isquit()
434        #    except:
435        #        raise FitAbort,"stop leastsqr optimizer"   
436        return self.res
437   
438class FitEngine:
439    def __init__(self):
440        """
441        Base class for scipy and park fit engine
442        """
443        #List of parameter names to fit
444        self.paramList=[]
445        #Dictionnary of fitArrange element (fit problems)
446        self.fitArrangeDict={}
447       
448    def _concatenateData(self, listdata=[]):
449        """ 
450        _concatenateData method concatenates each fields of all data
451        contains ins listdata.
452       
453        :param listdata: list of data
454       
455        :return Data: Data is wrapper class for sans plottable. it is created with all parameters
456            of data concatenanted
457           
458        :raise: if listdata is empty  will return None
459        :raise: if data in listdata don't contain dy field ,will create an error
460            during fitting
461           
462        """
463        #TODO: we have to refactor the way we handle data.
464        # We should move away from plottables and move towards the Data1D objects
465        # defined in DataLoader. Data1D allows data manipulations, which should be
466        # used to concatenate.
467        # In the meantime we should switch off the concatenation.
468        #if len(listdata)>1:
469        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
470        #return listdata[0]
471       
472        if listdata==[]:
473            raise ValueError, " data list missing"
474        else:
475            xtemp=[]
476            ytemp=[]
477            dytemp=[]
478            self.mini=None
479            self.maxi=None
480               
481            for item in listdata:
482                data=item.data
483                mini,maxi=data.getFitRange()
484                if self.mini==None and self.maxi==None:
485                    self.mini=mini
486                    self.maxi=maxi
487                else:
488                    if mini < self.mini:
489                        self.mini=mini
490                    if self.maxi < maxi:
491                        self.maxi=maxi
492                       
493                   
494                for i in range(len(data.x)):
495                    xtemp.append(data.x[i])
496                    ytemp.append(data.y[i])
497                    if data.dy is not None and len(data.dy)==len(data.y):   
498                        dytemp.append(data.dy[i])
499                    else:
500                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
501            data= Data(x=xtemp,y=ytemp,dy=dytemp)
502            data.setFitRange(self.mini, self.maxi)
503            return data
504       
505       
506    def set_model(self, model, Uid, pars=[], constraints=[]):
507        """
508        set a model on a given uid in the fit engine.
509       
510        :param model: sans.models type
511        :param Uid: is the key of the fitArrange dictionary where model is
512                saved as a value
513        :param pars: the list of parameters to fit
514        :param constraints: list of
515            tuple (name of parameter, value of parameters)
516            the value of parameter must be a string to constraint 2 different
517            parameters.
518            Example: 
519            we want to fit 2 model M1 and M2 both have parameters A and B.
520            constraints can be:
521             constraints = [(M1.A, M2.B+2), (M1.B= M2.A *5),...,]
522           
523             
524        :note: pars must contains only name of existing model's parameters
525       
526        """
527        if model == None:
528            raise ValueError, "AbstractFitEngine: Need to set model to fit"
529       
530        new_model= model
531        if not issubclass(model.__class__, Model):
532            new_model= Model(model)
533       
534        if len(constraints)>0:
535            for constraint in constraints:
536                name, value = constraint
537                try:
538                    new_model.parameterset[ str(name)].set( str(value) )
539                except:
540                    msg= "Fit Engine: Error occurs when setting the constraint"
541                    msg += " %s for parameter %s "%(value, name)
542                    raise ValueError, msg
543               
544        if len(pars) >0:
545            temp=[]
546            for item in pars:
547                if item in new_model.model.getParamList():
548                    temp.append(item)
549                    self.paramList.append(item)
550                else:
551                   
552                    msg = "wrong parameter %s used"%str(item)
553                    msg += "to set model %s. Choose"%str(new_model.model.name)
554                    msg += "parameter name within %s"%str(new_model.model.getParamList())
555                    raise ValueError,msg
556             
557            #A fitArrange is already created but contains dList only at Uid
558            if self.fitArrangeDict.has_key(Uid):
559                self.fitArrangeDict[Uid].set_model(new_model)
560                self.fitArrangeDict[Uid].pars= pars
561            else:
562            #no fitArrange object has been create with this Uid
563                fitproblem = FitArrange()
564                fitproblem.set_model(new_model)
565                fitproblem.pars= pars
566                self.fitArrangeDict[Uid] = fitproblem
567               
568        else:
569            raise ValueError, "park_integration:missing parameters"
570   
571    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
572        """
573        Receives plottable, creates a list of data to fit,set data
574        in a FitArrange object and adds that object in a dictionary
575        with key Uid.
576       
577        :param data: data added
578        :param Uid: unique key corresponding to a fitArrange object with data
579       
580        """
581        if data.__class__.__name__=='Data2D':
582            fitdata=FitData2D(sans_data2d=data, data=data.data, err_data= data.err_data)
583        else:
584            fitdata=FitData1D(x=data.x, y=data.y , dx= data.dx,dy=data.dy,smearer=smearer)
585       
586        fitdata.setFitRange(qmin=qmin,qmax=qmax)
587        #A fitArrange is already created but contains model only at Uid
588        if self.fitArrangeDict.has_key(Uid):
589            self.fitArrangeDict[Uid].add_data(fitdata)
590        else:
591        #no fitArrange object has been create with this Uid
592            fitproblem= FitArrange()
593            fitproblem.add_data(fitdata)
594            self.fitArrangeDict[Uid]=fitproblem   
595   
596    def get_model(self,Uid):
597        """
598       
599        :param Uid: Uid is key in the dictionary containing the model to return
600       
601        :return:  a model at this uid or None if no FitArrange element was created
602            with this Uid
603           
604        """
605        if self.fitArrangeDict.has_key(Uid):
606            return self.fitArrangeDict[Uid].get_model()
607        else:
608            return None
609   
610    def remove_Fit_Problem(self,Uid):
611        """remove   fitarrange in Uid"""
612        if self.fitArrangeDict.has_key(Uid):
613            del self.fitArrangeDict[Uid]
614           
615    def select_problem_for_fit(self,Uid,value):
616        """
617        select a couple of model and data at the Uid position in dictionary
618        and set in self.selected value to value
619       
620        :param value: the value to allow fitting.
621                can only have the value one or zero
622               
623        """
624        if self.fitArrangeDict.has_key(Uid):
625             self.fitArrangeDict[Uid].set_to_fit( value)
626             
627    def get_problem_to_fit(self,Uid):
628        """
629        return the self.selected value of the fit problem of Uid
630       
631        :param Uid: the Uid of the problem
632       
633        """
634        if self.fitArrangeDict.has_key(Uid):
635             self.fitArrangeDict[Uid].get_to_fit()
636   
637class FitArrange:
638    def __init__(self):
639        """
640        Class FitArrange contains a set of data for a given model
641        to perform the Fit.FitArrange must contain exactly one model
642        and at least one data for the fit to be performed.
643       
644        model: the model selected by the user
645        Ldata: a list of data what the user wants to fit
646           
647        """
648        self.model = None
649        self.dList =[]
650        self.pars=[]
651        #self.selected  is zero when this fit problem is not schedule to fit
652        #self.selected is 1 when schedule to fit
653        self.selected = 0
654       
655    def set_model(self,model):
656        """
657        set_model save a copy of the model
658       
659        :param model: the model being set
660       
661        """
662        self.model = model
663       
664    def add_data(self,data):
665        """
666        add_data fill a self.dList with data to fit
667       
668        :param data: Data to add in the list 
669       
670        """
671        if not data in self.dList:
672            self.dList.append(data)
673           
674    def get_model(self):
675        """
676       
677        :return: saved model
678       
679        """
680        return self.model   
681     
682    def get_data(self):
683        """
684       
685        :return: list of data dList
686       
687        """
688        #return self.dList
689        return self.dList[0] 
690     
691    def remove_data(self,data):
692        """
693        Remove one element from the list
694       
695        :param data: Data to remove from dList
696       
697        """
698        if data in self.dList:
699            self.dList.remove(data)
700           
701    def set_to_fit (self, value=0):
702        """
703        set self.selected to 0 or 1  for other values raise an exception
704       
705        :param value: integer between 0 or 1
706       
707        """
708        self.selected= value
709       
710    def get_to_fit(self):
711        """
712        return self.selected value
713        """
714        return self.selected
Note: See TracBrowser for help on using the repository browser.