source: sasview/park_integration/AbstractFitEngine.py @ b94945d

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since b94945d was aa36f96, checked in by Gervaise Alina <gervyh@…>, 14 years ago

working on documentation

  • Property mode set to 100644
File size: 24.0 KB
Line 
1
2
3import logging, sys
4import park,numpy,math, copy
5from DataLoader.data_info import Data1D
6from DataLoader.data_info import Data2D
7
8class SansParameter(park.Parameter):
9    """
10    SANS model parameters for use in the PARK fitting service.
11    The parameter attribute value is redirected to the underlying
12    parameter value in the SANS model.
13    """
14    def __init__(self, name, model):
15        """
16        :param name: the name of the model parameter
17        :param model: the sans model to wrap as a park model
18       
19        """
20        self._model, self._name = model,name
21        #set the value for the parameter of the given name
22        self.set(model.getParam(name))
23         
24    def _getvalue(self):
25        """
26        override the _getvalue of park parameter
27       
28        :return value the parameter associates with self.name
29       
30        """
31        return self._model.getParam(self.name)
32   
33    def _setvalue(self,value):
34        """
35        override the _setvalue pf park parameter
36       
37        :param value: the value to set on a given parameter
38       
39        """
40        self._model.setParam(self.name, value)
41       
42    value = property(_getvalue,_setvalue)
43   
44    def _getrange(self):
45        """
46        Override _getrange of park parameter
47        return the range of parameter
48        """
49        #if not  self.name in self._model.getDispParamList():
50        lo,hi = self._model.details[self.name][1:3]
51        if lo is None: lo = -numpy.inf
52        if hi is None: hi = numpy.inf
53        #else:
54            #lo,hi = self._model.details[self.name][1:]
55            #if lo is None: lo = -numpy.inf
56            #if hi is None: hi = numpy.inf
57        if lo >= hi:
58            raise ValueError,"wrong fit range for parameters"
59       
60        return lo,hi
61   
62    def _setrange(self,r):
63        """
64        override _setrange of park parameter
65       
66        :param r: the value of the range to set
67       
68        """
69        self._model.details[self.name][1:3] = r
70    range = property(_getrange,_setrange)
71   
72class Model(park.Model):
73    """
74    PARK wrapper for SANS models.
75    """
76    def __init__(self, sans_model, **kw):
77        """
78        :param sans_model: the sans model to wrap using park interface
79       
80        """
81        park.Model.__init__(self, **kw)
82        self.model = sans_model
83        self.name = sans_model.name
84        #list of parameters names
85        self.sansp = sans_model.getParamList()
86        #list of park parameter
87        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
88        #list of parameterset
89        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
90        self.pars=[]
91 
92    def getParams(self,fitparams):
93        """
94        return a list of value of paramter to fit
95       
96        :param fitparams: list of paramaters name to fit
97       
98        """
99        list=[]
100        self.pars=[]
101        self.pars=fitparams
102        for item in fitparams:
103            for element in self.parkp:
104                 if element.name ==str(item):
105                     list.append(element.value)
106        return list
107   
108    def setParams(self,paramlist, params):
109        """
110        Set value for parameters to fit
111       
112        :param params: list of value for parameters to fit
113       
114        """
115        try:
116            for i in range(len(self.parkp)):
117                for j in range(len(paramlist)):
118                    if self.parkp[i].name==paramlist[j]:
119                        self.parkp[i].value = params[j]
120                        self.model.setParam(self.parkp[i].name,params[j])
121        except:
122            raise
123 
124    def eval(self,x):
125        """
126        override eval method of park model.
127       
128        :param x: the x value used to compute a function
129       
130        """
131        try:
132            return self.model.evalDistribution(x)
133        except:
134            raise
135
136   
137class FitData1D(Data1D):
138    """
139    Wrapper class  for SANS data
140    FitData1D inherits from DataLoader.data_info.Data1D. Implements
141    a way to get residuals from data.
142    """
143    def __init__(self,x, y,dx= None, dy=None, smearer=None):
144        Data1D.__init__(self, x=numpy.array(x), y=numpy.array(y), dx=dx, dy=dy)
145        """
146        :param smearer: is an object of class QSmearer or SlitSmearer
147           that will smear the theory data (slit smearing or resolution
148           smearing) when set.
149       
150        The proper way to set the smearing object would be to
151        do the following: ::
152       
153            from DataLoader.qsmearing import smear_selection
154            smearer = smear_selection(some_data)
155            fitdata1d = FitData1D( x= [1,3,..,],
156                                    y= [3,4,..,8],
157                                    dx=None,
158                                    dy=[1,2...], smearer= smearer)
159       
160        :Note: that some_data _HAS_ to be of class DataLoader.data_info.Data1D
161            Setting it back to None will turn smearing off.
162           
163        """
164       
165        self.smearer = smearer
166        if dy ==None or dy==[]:
167            self.dy= numpy.zeros(len(self.y)) 
168        else:
169            self.dy= numpy.asarray(dy)
170     
171        # For fitting purposes, replace zero errors by 1
172        #TODO: check validity for the rare case where only
173        # a few points have zero errors
174        self.dy[self.dy==0]=1
175       
176        ## Min Q-value
177        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
178        if min (self.x) ==0.0 and self.x[0]==0 and not numpy.isfinite(self.y[0]):
179            self.qmin = min(self.x[self.x!=0])
180        else:                             
181            self.qmin= min (self.x)
182        ## Max Q-value
183        self.qmax = max (self.x)
184       
185        # Range used for input to smearing
186        self._qmin_unsmeared = self.qmin
187        self._qmax_unsmeared = self.qmax
188        # Identify the bin range for the unsmeared and smeared spaces
189        self.idx = (self.x>=self.qmin) & (self.x <= self.qmax)
190        self.idx_unsmeared = (self.x>=self._qmin_unsmeared) & (self.x <= self._qmax_unsmeared)
191 
192       
193       
194    def setFitRange(self,qmin=None,qmax=None):
195        """ to set the fit range"""
196        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
197        #ToDo: Fix this.
198        if qmin==0.0 and not numpy.isfinite(self.y[qmin]):
199            self.qmin = min(self.x[self.x!=0])
200        elif qmin!=None:                       
201            self.qmin = qmin           
202
203        if qmax !=None:
204            self.qmax = qmax
205           
206        # Determine the range needed in unsmeared-Q to cover
207        # the smeared Q range
208        self._qmin_unsmeared = self.qmin
209        self._qmax_unsmeared = self.qmax   
210       
211        self._first_unsmeared_bin = 0
212        self._last_unsmeared_bin  = len(self.x)-1
213       
214        if self.smearer!=None:
215            self._first_unsmeared_bin, self._last_unsmeared_bin = self.smearer.get_bin_range(self.qmin, self.qmax)
216            self._qmin_unsmeared = self.x[self._first_unsmeared_bin]
217            self._qmax_unsmeared = self.x[self._last_unsmeared_bin]
218           
219        # Identify the bin range for the unsmeared and smeared spaces
220        self.idx = (self.x>=self.qmin) & (self.x <= self.qmax)
221        self.idx_unsmeared = (self.x>=self._qmin_unsmeared) & (self.x <= self._qmax_unsmeared)
222 
223       
224    def getFitRange(self):
225        """
226        return the range of data.x to fit
227        """
228        return self.qmin, self.qmax
229       
230    def residuals(self, fn):
231        """
232        Compute residuals.
233       
234        If self.smearer has been set, use if to smear
235        the data before computing chi squared.
236       
237        :param fn: function that return model value
238       
239        :return: residuals
240       
241        """
242        # Compute theory data f(x)
243        fx= numpy.zeros(len(self.x))
244        fx[self.idx_unsmeared] = fn(self.x[self.idx_unsmeared])
245       
246        ## Smear theory data
247        if self.smearer is not None:
248            fx = self.smearer(fx, self._first_unsmeared_bin, self._last_unsmeared_bin)
249       
250        ## Sanity check
251        if numpy.size(self.dy)!= numpy.size(fx):
252            raise RuntimeError, "FitData1D: invalid error array %d <> %d" % (numpy.shape(self.dy),
253                                                                              numpy.size(fx))
254                                                                             
255        return (self.y[self.idx]-fx[self.idx])/self.dy[self.idx]
256     
257    def residuals_deriv(self, model, pars=[]):
258        """
259        :return: residuals derivatives .
260       
261        :note: in this case just return empty array
262       
263        """
264        return []
265   
266   
267class FitData2D(Data2D):
268    """ Wrapper class  for SANS data """
269    def __init__(self,sans_data2d ,data=None, err_data=None):
270        Data2D.__init__(self, data= data, err_data= err_data)
271        """
272        Data can be initital with a data (sans plottable)
273        or with vectors.
274        """
275        self.res_err_image=[]
276        self.index_model=[]
277        self.qmin= None
278        self.qmax= None
279        self.smearer = None
280        self.set_data(sans_data2d )
281
282       
283    def set_data(self, sans_data2d, qmin=None, qmax=None ):
284        """
285        Determine the correct qx_data and qy_data within range to fit
286        """
287        self.data     = sans_data2d.data
288        self.err_data = sans_data2d.err_data
289        self.qx_data = sans_data2d.qx_data
290        self.qy_data = sans_data2d.qy_data
291        self.mask       = sans_data2d.mask
292
293        x_max = max(math.fabs(sans_data2d.xmin), math.fabs(sans_data2d.xmax))
294        y_max = max(math.fabs(sans_data2d.ymin), math.fabs(sans_data2d.ymax))
295       
296        ## fitting range
297        if qmin == None:
298            self.qmin = 1e-16
299        if qmax == None:
300            self.qmax = math.sqrt(x_max*x_max +y_max*y_max)
301        ## new error image for fitting purpose
302        if self.err_data== None or self.err_data ==[]:
303            self.res_err_data= numpy.ones(len(self.data))
304        else:
305            self.res_err_data = copy.deepcopy(self.err_data)
306        #self.res_err_data[self.res_err_data==0]=1
307       
308        self.radius= numpy.sqrt(self.qx_data**2 + self.qy_data**2)
309       
310        # Note: mask = True: for MASK while mask = False for NOT to mask
311        self.index_model = ((self.qmin <= self.radius)&(self.radius<= self.qmax))
312        self.index_model = (self.index_model) & (self.mask)
313        self.index_model = (self.index_model) & (numpy.isfinite(self.data))
314       
315    def set_smearer(self,smearer): 
316        """
317        Set smearer
318        """
319        if smearer == None:
320            return
321        self.smearer = smearer
322        self.smearer.set_index(self.index_model)
323        self.smearer.get_data()
324
325    def setFitRange(self,qmin=None,qmax=None):
326        """ to set the fit range"""
327        if qmin==0.0:
328            self.qmin = 1e-16
329        elif qmin!=None:                       
330            self.qmin = qmin           
331        if qmax!=None:
332            self.qmax= qmax       
333        self.radius= numpy.sqrt(self.qx_data**2 + self.qy_data**2)
334        self.index_model = ((self.qmin <= self.radius)&(self.radius<= self.qmax))
335        self.index_model = (self.index_model) &(self.mask)
336        self.index_model = (self.index_model) & (numpy.isfinite(self.data))
337        self.index_model = (self.index_model) & (self.res_err_data!=0)
338       
339    def getFitRange(self):
340        """
341        return the range of data.x to fit
342        """
343        return self.qmin, self.qmax
344     
345    def residuals(self, fn): 
346        """
347        return the residuals
348        """ 
349        if self.smearer != None:
350            fn.set_index(self.index_model)
351            # Get necessary data from self.data and set the data for smearing
352            fn.get_data()
353
354            gn = fn.get_value() 
355        else:
356            gn = fn([self.qx_data[self.index_model],self.qy_data[self.index_model]])
357        # use only the data point within ROI range
358        res=(self.data[self.index_model] - gn)/self.res_err_data[self.index_model]
359        return res
360       
361    def residuals_deriv(self, model, pars=[]):
362        """
363        :return: residuals derivatives .
364       
365        :note: in this case just return empty array
366       
367        """
368        return []
369   
370class FitAbort(Exception):
371    """
372    Exception raise to stop the fit
373    """
374    #print"Creating fit abort Exception"
375
376
377class SansAssembly:
378    """
379    Sans Assembly class a class wrapper to be call in optimizer.leastsq method
380    """
381    def __init__(self, paramlist, model=None , data=None, fitresult=None,
382                 handler=None, curr_thread=None):
383        """
384        :param Model: the model wrapper fro sans -model
385        :param Data: the data wrapper for sans data
386       
387        """
388        self.model = model
389        self.data  = data
390        self.paramlist = paramlist
391        self.curr_thread = curr_thread
392        self.handler = handler
393        self.fitresult = fitresult
394        self.res = []
395        self.func_name = "Functor"
396       
397    def chisq(self, params):
398        """
399        Calculates chi^2
400       
401        :param params: list of parameter values
402       
403        :return: chi^2
404       
405        """
406        sum = 0
407        for item in self.res:
408            sum += item*item
409        if len(self.res)==0:
410            return None
411        return sum/ len(self.res)
412   
413    def __call__(self,params):
414        """
415        Compute residuals
416       
417        :param params: value of parameters to fit
418       
419        """
420        #import thread
421        self.model.setParams(self.paramlist,params)
422        self.res= self.data.residuals(self.model.eval)
423        if self.fitresult is not None and  self.handler is not None:
424            self.fitresult.set_model(model=self.model)
425            fitness = self.chisq(params=params)
426            self.fitresult.set_fitness(fitness=fitness)
427            self.handler.set_result(result=self.fitresult)
428            self.handler.update_fit()
429       
430        #if self.curr_thread != None :
431        #    try:
432        #        self.curr_thread.isquit()
433        #    except:
434        #        raise FitAbort,"stop leastsqr optimizer"   
435        return self.res
436   
437class FitEngine:
438    def __init__(self):
439        """
440        Base class for scipy and park fit engine
441        """
442        #List of parameter names to fit
443        self.paramList=[]
444        #Dictionnary of fitArrange element (fit problems)
445        self.fitArrangeDict={}
446       
447    def _concatenateData(self, listdata=[]):
448        """ 
449        _concatenateData method concatenates each fields of all data
450        contains ins listdata.
451       
452        :param listdata: list of data
453       
454        :return Data: Data is wrapper class for sans plottable. it is created with all parameters
455            of data concatenanted
456           
457        :raise: if listdata is empty  will return None
458        :raise: if data in listdata don't contain dy field ,will create an error
459            during fitting
460           
461        """
462        #TODO: we have to refactor the way we handle data.
463        # We should move away from plottables and move towards the Data1D objects
464        # defined in DataLoader. Data1D allows data manipulations, which should be
465        # used to concatenate.
466        # In the meantime we should switch off the concatenation.
467        #if len(listdata)>1:
468        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
469        #return listdata[0]
470       
471        if listdata==[]:
472            raise ValueError, " data list missing"
473        else:
474            xtemp=[]
475            ytemp=[]
476            dytemp=[]
477            self.mini=None
478            self.maxi=None
479               
480            for item in listdata:
481                data=item.data
482                mini,maxi=data.getFitRange()
483                if self.mini==None and self.maxi==None:
484                    self.mini=mini
485                    self.maxi=maxi
486                else:
487                    if mini < self.mini:
488                        self.mini=mini
489                    if self.maxi < maxi:
490                        self.maxi=maxi
491                       
492                   
493                for i in range(len(data.x)):
494                    xtemp.append(data.x[i])
495                    ytemp.append(data.y[i])
496                    if data.dy is not None and len(data.dy)==len(data.y):   
497                        dytemp.append(data.dy[i])
498                    else:
499                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
500            data= Data(x=xtemp,y=ytemp,dy=dytemp)
501            data.setFitRange(self.mini, self.maxi)
502            return data
503       
504       
505    def set_model(self, model, Uid, pars=[], constraints=[]):
506        """
507        set a model on a given uid in the fit engine.
508       
509        :param model: sans.models type
510        :param Uid: is the key of the fitArrange dictionary where model is
511                saved as a value
512        :param pars: the list of parameters to fit
513        :param constraints: list of
514            tuple (name of parameter, value of parameters)
515            the value of parameter must be a string to constraint 2 different
516            parameters.
517            Example: 
518            we want to fit 2 model M1 and M2 both have parameters A and B.
519            constraints can be:
520             constraints = [(M1.A, M2.B+2), (M1.B= M2.A *5),...,]
521           
522             
523        :note: pars must contains only name of existing model's parameters
524       
525        """
526        if model == None:
527            raise ValueError, "AbstractFitEngine: Need to set model to fit"
528       
529        new_model= model
530        if not issubclass(model.__class__, Model):
531            new_model= Model(model)
532       
533        if len(constraints)>0:
534            for constraint in constraints:
535                name, value = constraint
536                try:
537                    new_model.parameterset[ str(name)].set( str(value) )
538                except:
539                    msg= "Fit Engine: Error occurs when setting the constraint"
540                    msg += " %s for parameter %s "%(value, name)
541                    raise ValueError, msg
542               
543        if len(pars) >0:
544            temp=[]
545            for item in pars:
546                if item in new_model.model.getParamList():
547                    temp.append(item)
548                    self.paramList.append(item)
549                else:
550                   
551                    msg = "wrong parameter %s used"%str(item)
552                    msg += "to set model %s. Choose"%str(new_model.model.name)
553                    msg += "parameter name within %s"%str(new_model.model.getParamList())
554                    raise ValueError,msg
555             
556            #A fitArrange is already created but contains dList only at Uid
557            if self.fitArrangeDict.has_key(Uid):
558                self.fitArrangeDict[Uid].set_model(new_model)
559                self.fitArrangeDict[Uid].pars= pars
560            else:
561            #no fitArrange object has been create with this Uid
562                fitproblem = FitArrange()
563                fitproblem.set_model(new_model)
564                fitproblem.pars= pars
565                self.fitArrangeDict[Uid] = fitproblem
566               
567        else:
568            raise ValueError, "park_integration:missing parameters"
569   
570    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
571        """
572        Receives plottable, creates a list of data to fit,set data
573        in a FitArrange object and adds that object in a dictionary
574        with key Uid.
575       
576        :param data: data added
577        :param Uid: unique key corresponding to a fitArrange object with data
578       
579        """
580        if data.__class__.__name__=='Data2D':
581            fitdata=FitData2D(sans_data2d=data, data=data.data, err_data= data.err_data)
582        else:
583            fitdata=FitData1D(x=data.x, y=data.y , dx= data.dx,dy=data.dy,smearer=smearer)
584       
585        fitdata.setFitRange(qmin=qmin,qmax=qmax)
586        #A fitArrange is already created but contains model only at Uid
587        if self.fitArrangeDict.has_key(Uid):
588            self.fitArrangeDict[Uid].add_data(fitdata)
589        else:
590        #no fitArrange object has been create with this Uid
591            fitproblem= FitArrange()
592            fitproblem.add_data(fitdata)
593            self.fitArrangeDict[Uid]=fitproblem   
594   
595    def get_model(self,Uid):
596        """
597       
598        :param Uid: Uid is key in the dictionary containing the model to return
599       
600        :return:  a model at this uid or None if no FitArrange element was created
601            with this Uid
602           
603        """
604        if self.fitArrangeDict.has_key(Uid):
605            return self.fitArrangeDict[Uid].get_model()
606        else:
607            return None
608   
609    def remove_Fit_Problem(self,Uid):
610        """remove   fitarrange in Uid"""
611        if self.fitArrangeDict.has_key(Uid):
612            del self.fitArrangeDict[Uid]
613           
614    def select_problem_for_fit(self,Uid,value):
615        """
616        select a couple of model and data at the Uid position in dictionary
617        and set in self.selected value to value
618       
619        :param value: the value to allow fitting.
620                can only have the value one or zero
621               
622        """
623        if self.fitArrangeDict.has_key(Uid):
624             self.fitArrangeDict[Uid].set_to_fit( value)
625             
626    def get_problem_to_fit(self,Uid):
627        """
628        return the self.selected value of the fit problem of Uid
629       
630        :param Uid: the Uid of the problem
631       
632        """
633        if self.fitArrangeDict.has_key(Uid):
634             self.fitArrangeDict[Uid].get_to_fit()
635   
636class FitArrange:
637    def __init__(self):
638        """
639        Class FitArrange contains a set of data for a given model
640        to perform the Fit.FitArrange must contain exactly one model
641        and at least one data for the fit to be performed.
642       
643        model: the model selected by the user
644        Ldata: a list of data what the user wants to fit
645           
646        """
647        self.model = None
648        self.dList =[]
649        self.pars=[]
650        #self.selected  is zero when this fit problem is not schedule to fit
651        #self.selected is 1 when schedule to fit
652        self.selected = 0
653       
654    def set_model(self,model):
655        """
656        set_model save a copy of the model
657       
658        :param model: the model being set
659       
660        """
661        self.model = model
662       
663    def add_data(self,data):
664        """
665        add_data fill a self.dList with data to fit
666       
667        :param data: Data to add in the list 
668       
669        """
670        if not data in self.dList:
671            self.dList.append(data)
672           
673    def get_model(self):
674        """
675       
676        :return: saved model
677       
678        """
679        return self.model   
680     
681    def get_data(self):
682        """
683       
684        :return: list of data dList
685       
686        """
687        #return self.dList
688        return self.dList[0] 
689     
690    def remove_data(self,data):
691        """
692        Remove one element from the list
693       
694        :param data: Data to remove from dList
695       
696        """
697        if data in self.dList:
698            self.dList.remove(data)
699           
700    def set_to_fit (self, value=0):
701        """
702        set self.selected to 0 or 1  for other values raise an exception
703       
704        :param value: integer between 0 or 1
705       
706        """
707        self.selected= value
708       
709    def get_to_fit(self):
710        """
711        return self.selected value
712        """
713        return self.selected
Note: See TracBrowser for help on using the repository browser.