source: sasview/park_integration/AbstractFitEngine.py @ b341b16

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since b341b16 was 72c7d31, checked in by Mathieu Doucet <doucetm@…>, 15 years ago

park_integration: modified residuals to allow for partial q range when using smeared data

  • Property mode set to 100644
File size: 22.9 KB
Line 
1import logging, sys
2import park,numpy,math, copy
3
4class SansParameter(park.Parameter):
5    """
6        SANS model parameters for use in the PARK fitting service.
7        The parameter attribute value is redirected to the underlying
8        parameter value in the SANS model.
9    """
10    def __init__(self, name, model):
11        """
12            @param name: the name of the model parameter
13            @param model: the sans model to wrap as a park model
14        """
15        self._model, self._name = model,name
16        #set the value for the parameter of the given name
17        self.set(model.getParam(name))
18         
19    def _getvalue(self):
20        """
21            override the _getvalue of park parameter
22            @return value the parameter associates with self.name
23        """
24        return self._model.getParam(self.name)
25   
26    def _setvalue(self,value):
27        """
28            override the _setvalue pf park parameter
29            @param value: the value to set on a given parameter
30        """
31        self._model.setParam(self.name, value)
32       
33    value = property(_getvalue,_setvalue)
34   
35    def _getrange(self):
36        """
37            Override _getrange of park parameter
38            return the range of parameter
39        """
40        if not  self.name in self._model.getDispParamList():
41            lo,hi = self._model.details[self.name][1:]
42            if lo is None: lo = -numpy.inf
43            if hi is None: hi = numpy.inf
44        else:
45            lo= -numpy.inf
46            hi= numpy.inf
47        if lo >= hi:
48            raise ValueError,"wrong fit range for parameters"
49       
50        return lo,hi
51   
52    def _setrange(self,r):
53        """
54            override _setrange of park parameter
55            @param r: the value of the range to set
56        """
57        self._model.details[self.name][1:] = r
58    range = property(_getrange,_setrange)
59   
60class Model(park.Model):
61    """
62        PARK wrapper for SANS models.
63    """
64    def __init__(self, sans_model, **kw):
65        """
66            @param sans_model: the sans model to wrap using park interface
67        """
68        park.Model.__init__(self, **kw)
69        self.model = sans_model
70        self.name = sans_model.name
71        #list of parameters names
72        self.sansp = sans_model.getParamList()
73        #list of park parameter
74        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
75        #list of parameterset
76        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
77        self.pars=[]
78 
79 
80    def getParams(self,fitparams):
81        """
82            return a list of value of paramter to fit
83            @param fitparams: list of paramaters name to fit
84        """
85        list=[]
86        self.pars=[]
87        self.pars=fitparams
88        for item in fitparams:
89            for element in self.parkp:
90                 if element.name ==str(item):
91                     list.append(element.value)
92        return list
93   
94   
95    def setParams(self,paramlist, params):
96        """
97            Set value for parameters to fit
98            @param params: list of value for parameters to fit
99        """
100        try:
101            for i in range(len(self.parkp)):
102                for j in range(len(paramlist)):
103                    if self.parkp[i].name==paramlist[j]:
104                        self.parkp[i].value = params[j]
105                        self.model.setParam(self.parkp[i].name,params[j])
106        except:
107            raise
108 
109    def eval(self,x):
110        """
111            override eval method of park model.
112            @param x: the x value used to compute a function
113        """
114        return self.model.runXY(x)
115   
116   
117
118
119class Data(object):
120    """ Wrapper class  for SANS data """
121    def __init__(self,x=None,y=None,dy=None,dx=None,sans_data=None):
122        """
123            Data can be initital with a data (sans plottable)
124            or with vectors.
125        """
126        if  sans_data !=None:
127            self.x= sans_data.x
128            self.y= sans_data.y
129            self.dx= sans_data.dx
130            self.dy= sans_data.dy
131           
132        elif (x!=None and y!=None and dy!=None):
133                self.x=x
134                self.y=y
135                self.dx=dx
136                self.dy=dy
137        else:
138            raise ValueError,\
139            "Data is missing x, y or dy, impossible to compute residuals later on"
140        self.qmin=None
141        self.qmax=None
142       
143       
144    def setFitRange(self,mini=None,maxi=None):
145        """ to set the fit range"""
146       
147        self.qmin=mini           
148        self.qmax=maxi
149       
150       
151    def getFitRange(self):
152        """
153            @return the range of data.x to fit
154        """
155        return self.qmin, self.qmax
156     
157     
158    def residuals(self, fn):
159        """ @param fn: function that return model value
160            @return residuals
161        """
162        x,y,dy = [numpy.asarray(v) for v in (self.x,self.y,self.dy)]
163        if self.qmin==None and self.qmax==None: 
164            fx =numpy.asarray([fn(v) for v in x])
165            return (y - fx)/dy
166        else:
167            idx = (x>=self.qmin) & (x <= self.qmax)
168            fx = numpy.asarray([fn(item)for item in x[idx ]])
169            return (y[idx] - fx)/dy[idx]
170       
171    def residuals_deriv(self, model, pars=[]):
172        """
173            @return residuals derivatives .
174            @note: in this case just return empty array
175        """
176        return []
177   
178   
179class FitData1D(object):
180    """ Wrapper class  for SANS data """
181    def __init__(self,sans_data1d, smearer=None):
182        """
183            Data can be initital with a data (sans plottable)
184            or with vectors.
185           
186            self.smearer is an object of class QSmearer or SlitSmearer
187            that will smear the theory data (slit smearing or resolution
188            smearing) when set.
189           
190            The proper way to set the smearing object would be to
191            do the following:
192           
193            from DataLoader.qsmearing import smear_selection
194            fitdata1d = FitData1D(some_data)
195            fitdata1d.smearer = smear_selection(some_data)
196           
197            Note that some_data _HAS_ to be of class DataLoader.data_info.Data1D
198           
199            Setting it back to None will turn smearing off.
200           
201        """
202       
203        self.smearer = smearer
204     
205        # Initialize from Data1D object
206        self.data=sans_data1d
207        self.x= sans_data1d.x
208        self.y= sans_data1d.y
209        self.dx= sans_data1d.dx
210        self.dy= sans_data1d.dy
211       
212        ## Min Q-value
213        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
214        if min (self.data.x) ==0.0 and self.data.x[0]==0 and not numpy.isfinite(self.data.y[0]):
215            self.qmin = min(self.data.x[self.data.x!=0])
216        else:                             
217            self.qmin= min (self.data.x)
218        ## Max Q-value
219        self.qmax= max (self.data.x)
220       
221        # Range used for input to smearing
222        self._qmin_unsmeared = self.qmin
223        self._qmax_unsmeared = self.qmax
224       
225       
226    def setFitRange(self,qmin=None,qmax=None):
227        """ to set the fit range"""
228        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
229        #ToDo: Fix this.
230        if qmin==0.0 and not numpy.isfinite(self.data.y[qmin]):
231            self.qmin = min(self.data.x[self.data.x!=0])
232        elif qmin!=None:                       
233            self.qmin = qmin           
234
235        if qmax !=None:
236            self.qmax = qmax
237           
238        # Range used for input to smearing
239        self._qmin_unsmeared = self.qmin
240        self._qmax_unsmeared = self.qmax   
241       
242        # Determine the range needed in unsmeared-Q to cover
243        # the smeared Q range
244        #TODO: use the smearing matrix to determine which
245        # bin range to use
246        if self.smearer.__class__.__name__ == 'SlitSmearer':
247            self._qmin_unsmeared = min(self.data.x)
248            self._qmax_unsmeared = max(self.data.x)
249        elif self.smearer.__class__.__name__ == 'QSmearer':
250            # Take 3 sigmas as the offset between smeared and unsmeared space
251            try:
252                offset = 3.0*max(self.smearer.width)
253                self._qmin_unsmeared = max([min(self.data.x), self.qmin-offset])
254                self._qmax_unsmeared = min([max(self.data.x), self.qmax+offset])
255            except:
256                logging.error("FitData1D.setFitRange: %s" % sys.exc_value)
257       
258       
259    def getFitRange(self):
260        """
261            @return the range of data.x to fit
262        """
263        return self.qmin, self.qmax
264       
265    def residuals(self, fn):
266        """
267            Compute residuals.
268           
269            If self.smearer has been set, use if to smear
270            the data before computing chi squared.
271           
272            @param fn: function that return model value
273            @return residuals
274        """
275        x,y = [numpy.asarray(v) for v in (self.x,self.y)]
276        if self.dy ==None or self.dy==[]:
277            dy= numpy.zeros(len(y)) 
278        else:
279            dy= numpy.asarray(self.dy)
280     
281        # For fitting purposes, replace zero errors by 1
282        #TODO: check validity for the rare case where only
283        # a few points have zero errors
284        dy[dy==0]=1
285       
286        # Identify the bin range for the unsmeared and smeared spaces
287        idx = (x>=self.qmin) & (x <= self.qmax)
288        idx_unsmeared = (x>=self._qmin_unsmeared) & (x <= self._qmax_unsmeared)
289 
290        # Compute theory data f(x)
291        fx= numpy.zeros(len(x))
292   
293        _first_bin = None
294        _last_bin  = None
295        for i_x in range(len(x)):
296            try:
297                if idx_unsmeared[i_x]==True:
298                    # Identify first and last bin
299                    #TODO: refactor this to pass q-values to the smearer
300                    # and let it figure out which bin range to use
301                    if _first_bin is None:
302                        _first_bin = i_x
303                    else:
304                        _last_bin  = i_x
305                   
306                    value = fn(x[i_x])
307                    fx[i_x] = value
308            except:
309                ## skip error for model.run(x)
310                pass
311                 
312        ## Smear theory data
313        if self.smearer is not None:
314            fx = self.smearer(fx, _first_bin, _last_bin)
315       
316        ## Sanity check
317        if numpy.size(dy)!= numpy.size(fx):
318            raise RuntimeError, "FitData1D: invalid error array %d <> %d" % (numpy.size(dy), numpy.size(fx))
319
320        return (y[idx]-fx[idx])/dy[idx]
321     
322 
323       
324    def residuals_deriv(self, model, pars=[]):
325        """
326            @return residuals derivatives .
327            @note: in this case just return empty array
328        """
329        return []
330   
331   
332class FitData2D(object):
333    """ Wrapper class  for SANS data """
334    def __init__(self,sans_data2d):
335        """
336            Data can be initital with a data (sans plottable)
337            or with vectors.
338        """
339        self.data=sans_data2d
340        self.image = sans_data2d.data
341        self.err_image = sans_data2d.err_data
342        self.x_bins= sans_data2d.x_bins
343        self.y_bins= sans_data2d.y_bins
344       
345        x = max(self.data.xmin, self.data.xmax)
346        y = max(self.data.ymin, self.data.ymax)
347       
348        ## fitting range
349        self.qmin = 1e-16
350        self.qmax = math.sqrt(x*x +y*y)
351        ## new error image for fitting purpose
352        if self.err_image== None or self.err_image ==[]:
353            self.res_err_image= numpy.zeros(len(self.y_bins),len(self.x_bins))
354        else:
355            self.res_err_image = copy.deepcopy(self.err_image)
356        self.res_err_image[self.err_image==0]=1
357       
358       
359    def setFitRange(self,qmin=None,qmax=None):
360        """ to set the fit range"""
361        if qmin==0.0:
362            self.qmin = 1e-16
363        elif qmin!=None:                       
364            self.qmin = qmin           
365        if qmax!=None:
366            self.qmax= qmax
367     
368       
369    def getFitRange(self):
370        """
371            @return the range of data.x to fit
372        """
373        return self.qmin, self.qmax
374     
375     
376    def residuals(self, fn):
377        """ @param fn: function that return model value
378            @return residuals
379        """
380        res=[]
381       
382        for i in range(len(self.x_bins)):
383            for j in range(len(self.y_bins)):
384                temp = math.pow(self.data.x_bins[i],2)+math.pow(self.data.y_bins[j],2)
385                radius= math.sqrt(temp)
386                if self.qmin <= radius and radius <= self.qmax:
387                    res.append( (self.image[j][i]- fn([self.x_bins[i],self.y_bins[j]]))\
388                            /self.res_err_image[j][i] )
389       
390        return numpy.array(res)
391       
392         
393    def residuals_deriv(self, model, pars=[]):
394        """
395            @return residuals derivatives .
396            @note: in this case just return empty array
397        """
398        return []
399   
400class FitAbort(Exception):
401    """
402        Exception raise to stop the fit
403    """
404    print"Creating fit abort Exception"
405
406
407class SansAssembly:
408    """
409         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
410    """
411    def __init__(self,paramlist,Model=None , Data=None, curr_thread= None):
412        """
413            @param Model: the model wrapper fro sans -model
414            @param Data: the data wrapper for sans data
415        """
416        self.model = Model
417        self.data  = Data
418        self.paramlist=paramlist
419        self.curr_thread= curr_thread
420        self.res=[]
421        self.func_name="Functor"
422    def chisq(self, params):
423        """
424            Calculates chi^2
425            @param params: list of parameter values
426            @return: chi^2
427        """
428        sum = 0
429        for item in self.res:
430            sum += item*item
431        if len(self.res)==0:
432            return None
433        return sum/ len(self.res)
434   
435    def __call__(self,params):
436        """
437            Compute residuals
438            @param params: value of parameters to fit
439        """
440        #import thread
441        self.model.setParams(self.paramlist,params)
442        self.res= self.data.residuals(self.model.eval)
443        #if self.curr_thread != None :
444        #    try:
445        #        self.curr_thread.isquit()
446        #    except:
447        #        raise FitAbort,"stop leastsqr optimizer"   
448        return self.res
449   
450class FitEngine:
451    def __init__(self):
452        """
453            Base class for scipy and park fit engine
454        """
455        #List of parameter names to fit
456        self.paramList=[]
457        #Dictionnary of fitArrange element (fit problems)
458        self.fitArrangeDict={}
459       
460    def _concatenateData(self, listdata=[]):
461        """ 
462            _concatenateData method concatenates each fields of all data contains ins listdata.
463            @param listdata: list of data
464            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
465             of data concatenanted
466            @raise: if listdata is empty  will return None
467            @raise: if data in listdata don't contain dy field ,will create an error
468            during fitting
469        """
470        #TODO: we have to refactor the way we handle data.
471        # We should move away from plottables and move towards the Data1D objects
472        # defined in DataLoader. Data1D allows data manipulations, which should be
473        # used to concatenate.
474        # In the meantime we should switch off the concatenation.
475        #if len(listdata)>1:
476        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
477        #return listdata[0]
478       
479        if listdata==[]:
480            raise ValueError, " data list missing"
481        else:
482            xtemp=[]
483            ytemp=[]
484            dytemp=[]
485            self.mini=None
486            self.maxi=None
487               
488            for item in listdata:
489                data=item.data
490                mini,maxi=data.getFitRange()
491                if self.mini==None and self.maxi==None:
492                    self.mini=mini
493                    self.maxi=maxi
494                else:
495                    if mini < self.mini:
496                        self.mini=mini
497                    if self.maxi < maxi:
498                        self.maxi=maxi
499                       
500                   
501                for i in range(len(data.x)):
502                    xtemp.append(data.x[i])
503                    ytemp.append(data.y[i])
504                    if data.dy is not None and len(data.dy)==len(data.y):   
505                        dytemp.append(data.dy[i])
506                    else:
507                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
508            data= Data(x=xtemp,y=ytemp,dy=dytemp)
509            data.setFitRange(self.mini, self.maxi)
510            return data
511       
512       
513    def set_model(self,model,Uid,pars=[]):
514        """
515            set a model on a given uid in the fit engine.
516            @param model: the model to fit
517            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
518            @param pars: the list of parameters to fit
519            @note : pars must contains only name of existing model's paramaters
520        """
521        if len(pars) >0:
522            if model==None:
523                raise ValueError, "AbstractFitEngine: Specify parameters to fit"
524            else:
525                temp=[]
526                for item in pars:
527                    if item in model.model.getParamList():
528                        temp.append(item)
529                        self.paramList.append(item)
530                    else:
531                        raise ValueError,"wrong paramter %s used to set model %s. Choose\
532                            parameter name within %s"%(item, model.model.name,str(model.model.getParamList()))
533                        return
534            #A fitArrange is already created but contains dList only at Uid
535            if self.fitArrangeDict.has_key(Uid):
536                self.fitArrangeDict[Uid].set_model(model)
537                self.fitArrangeDict[Uid].pars= pars
538            else:
539            #no fitArrange object has been create with this Uid
540                fitproblem = FitArrange()
541                fitproblem.set_model(model)
542                fitproblem.pars= pars
543                self.fitArrangeDict[Uid] = fitproblem
544               
545        else:
546            raise ValueError, "park_integration:missing parameters"
547   
548    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
549        """ Receives plottable, creates a list of data to fit,set data
550            in a FitArrange object and adds that object in a dictionary
551            with key Uid.
552            @param data: data added
553            @param Uid: unique key corresponding to a fitArrange object with data
554        """
555        if data.__class__.__name__=='Data2D':
556            fitdata=FitData2D(data)
557        else:
558            fitdata=FitData1D(data, smearer)
559       
560        fitdata.setFitRange(qmin=qmin,qmax=qmax)
561        #A fitArrange is already created but contains model only at Uid
562        if self.fitArrangeDict.has_key(Uid):
563            self.fitArrangeDict[Uid].add_data(fitdata)
564        else:
565        #no fitArrange object has been create with this Uid
566            fitproblem= FitArrange()
567            fitproblem.add_data(fitdata)
568            self.fitArrangeDict[Uid]=fitproblem   
569   
570    def get_model(self,Uid):
571        """
572            @param Uid: Uid is key in the dictionary containing the model to return
573            @return  a model at this uid or None if no FitArrange element was created
574            with this Uid
575        """
576        if self.fitArrangeDict.has_key(Uid):
577            return self.fitArrangeDict[Uid].get_model()
578        else:
579            return None
580   
581    def remove_Fit_Problem(self,Uid):
582        """remove   fitarrange in Uid"""
583        if self.fitArrangeDict.has_key(Uid):
584            del self.fitArrangeDict[Uid]
585           
586    def select_problem_for_fit(self,Uid,value):
587        """
588            select a couple of model and data at the Uid position in dictionary
589            and set in self.selected value to value
590            @param value: the value to allow fitting. can only have the value one or zero
591        """
592        if self.fitArrangeDict.has_key(Uid):
593             self.fitArrangeDict[Uid].set_to_fit( value)
594             
595             
596    def get_problem_to_fit(self,Uid):
597        """
598            return the self.selected value of the fit problem of Uid
599           @param Uid: the Uid of the problem
600        """
601        if self.fitArrangeDict.has_key(Uid):
602             self.fitArrangeDict[Uid].get_to_fit()
603   
604class FitArrange:
605    def __init__(self):
606        """
607            Class FitArrange contains a set of data for a given model
608            to perform the Fit.FitArrange must contain exactly one model
609            and at least one data for the fit to be performed.
610            model: the model selected by the user
611            Ldata: a list of data what the user wants to fit
612           
613        """
614        self.model = None
615        self.dList =[]
616        self.pars=[]
617        #self.selected  is zero when this fit problem is not schedule to fit
618        #self.selected is 1 when schedule to fit
619        self.selected = 0
620       
621    def set_model(self,model):
622        """
623            set_model save a copy of the model
624            @param model: the model being set
625        """
626        self.model = model
627       
628    def add_data(self,data):
629        """
630            add_data fill a self.dList with data to fit
631            @param data: Data to add in the list 
632        """
633        if not data in self.dList:
634            self.dList.append(data)
635           
636    def get_model(self):
637        """ @return: saved model """
638        return self.model   
639     
640    def get_data(self):
641        """ @return:  list of data dList"""
642        #return self.dList
643        return self.dList[0] 
644     
645    def remove_data(self,data):
646        """
647            Remove one element from the list
648            @param data: Data to remove from dList
649        """
650        if data in self.dList:
651            self.dList.remove(data)
652    def set_to_fit (self, value=0):
653        """
654           set self.selected to 0 or 1  for other values raise an exception
655           @param value: integer between 0 or 1
656        """
657        self.selected= value
658       
659    def get_to_fit(self):
660        """
661            @return self.selected value
662        """
663        return self.selected
Note: See TracBrowser for help on using the repository browser.