source: sasview/park_integration/AbstractFitEngine.py @ 6d20b46

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 6d20b46 was 6d20b46, checked in by Gervaise Alina <gervyh@…>, 15 years ago

working on thread issues [incomplete]

  • Property mode set to 100644
File size: 22.8 KB
Line 
1import logging, sys
2import copy
3from copy import deepcopy
4import park,numpy,math, copy
5from DataLoader.data_info import Data1D
6from DataLoader.data_info import Data2D
7class SansParameter(park.Parameter):
8    """
9        SANS model parameters for use in the PARK fitting service.
10        The parameter attribute value is redirected to the underlying
11        parameter value in the SANS model.
12    """
13    def __init__(self, name, model):
14        """
15            @param name: the name of the model parameter
16            @param model: the sans model to wrap as a park model
17        """
18        self._model, self._name = model,name
19        #set the value for the parameter of the given name
20        self.set(model.getParam(name))
21         
22    def _getvalue(self):
23        """
24            override the _getvalue of park parameter
25            @return value the parameter associates with self.name
26        """
27        return self._model.getParam(self.name)
28   
29    def _setvalue(self,value):
30        """
31            override the _setvalue pf park parameter
32            @param value: the value to set on a given parameter
33        """
34        self._model.setParam(self.name, value)
35       
36    value = property(_getvalue,_setvalue)
37   
38    def _getrange(self):
39        """
40            Override _getrange of park parameter
41            return the range of parameter
42        """
43        #if not  self.name in self._model.getDispParamList():
44        lo,hi = self._model.details[self.name][1:3]
45        if lo is None: lo = -numpy.inf
46        if hi is None: hi = numpy.inf
47        #else:
48            #lo,hi = self._model.details[self.name][1:]
49            #if lo is None: lo = -numpy.inf
50            #if hi is None: hi = numpy.inf
51        if lo >= hi:
52            raise ValueError,"wrong fit range for parameters"
53       
54        return lo,hi
55   
56    def _setrange(self,r):
57        """
58            override _setrange of park parameter
59            @param r: the value of the range to set
60        """
61        self._model.details[self.name][1:3] = r
62    range = property(_getrange,_setrange)
63   
64class Model(park.Model):
65    """
66        PARK wrapper for SANS models.
67    """
68    def __init__(self, sans_model, **kw):
69        """
70            @param sans_model: the sans model to wrap using park interface
71        """
72        park.Model.__init__(self, **kw)
73        self.model = sans_model
74        self.name = sans_model.name
75        #list of parameters names
76        self.sansp = sans_model.getParamList()
77        #list of park parameter
78        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
79        #list of parameterset
80        self.parameterset = park.ParameterSet(sans_model.name,pars=self.parkp)
81        self.pars=[]
82 
83    def clone(self):
84        model = self.model.clone()
85        model.name = self.model.name
86        return Model(model)
87    def getParams(self,fitparams):
88        """
89            return a list of value of paramter to fit
90            @param fitparams: list of paramaters name to fit
91        """
92        list=[]
93        self.pars=[]
94        self.pars=fitparams
95        for item in fitparams:
96            for element in self.parkp:
97                 if element.name ==str(item):
98                     list.append(element.value)
99        return list
100   
101   
102    def setParams(self,paramlist, params):
103        """
104            Set value for parameters to fit
105            @param params: list of value for parameters to fit
106        """
107        try:
108            for i in range(len(self.parkp)):
109                for j in range(len(paramlist)):
110                    if self.parkp[i].name==paramlist[j]:
111                        self.parkp[i].value = params[j]
112                        self.model.setParam(self.parkp[i].name,params[j])
113        except:
114            raise
115 
116    def eval(self,x):
117        """
118            override eval method of park model.
119            @param x: the x value used to compute a function
120        """
121        try:
122                return self.model.evalDistribution(x)
123        except:
124                raise
125
126   
127class FitData1D(Data1D):
128    """
129        Wrapper class  for SANS data
130        FitData1D inherits from DataLoader.data_info.Data1D. Implements
131        a way to get residuals from data.
132    """
133    def __init__(self,x, y,dx= None, dy=None, smearer=None):
134        Data1D.__init__(self, x=numpy.array(x), y=numpy.array(y), dx=dx, dy=dy)
135        """
136            @param smearer: is an object of class QSmearer or SlitSmearer
137            that will smear the theory data (slit smearing or resolution
138            smearing) when set.
139           
140            The proper way to set the smearing object would be to
141            do the following:
142           
143            from DataLoader.qsmearing import smear_selection
144            smearer = smear_selection(some_data)
145            fitdata1d = FitData1D( x= [1,3,..,],
146                                    y= [3,4,..,8],
147                                    dx=None,
148                                    dy=[1,2...], smearer= smearer)
149           
150            Note that some_data _HAS_ to be of class DataLoader.data_info.Data1D
151           
152            Setting it back to None will turn smearing off.
153           
154        """
155       
156        self.smearer = smearer
157        if dy ==None or dy==[]:
158            self.dy= numpy.zeros(len(self.y)) 
159        else:
160            self.dy= numpy.asarray(dy)
161     
162        # For fitting purposes, replace zero errors by 1
163        #TODO: check validity for the rare case where only
164        # a few points have zero errors
165        self.dy[self.dy==0]=1
166       
167        ## Min Q-value
168        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
169        if min (self.x) ==0.0 and self.x[0]==0 and not numpy.isfinite(self.y[0]):
170            self.qmin = min(self.x[self.x!=0])
171        else:                             
172            self.qmin= min (self.x)
173        ## Max Q-value
174        self.qmax = max (self.x)
175       
176        # Range used for input to smearing
177        self._qmin_unsmeared = self.qmin
178        self._qmax_unsmeared = self.qmax
179        # Identify the bin range for the unsmeared and smeared spaces
180        self.idx = (self.x>=self.qmin) & (self.x <= self.qmax)
181        self.idx_unsmeared = (self.x>=self._qmin_unsmeared) & (self.x <= self._qmax_unsmeared)
182 
183       
184       
185    def setFitRange(self,qmin=None,qmax=None):
186        """ to set the fit range"""
187        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
188        #ToDo: Fix this.
189        if qmin==0.0 and not numpy.isfinite(self.y[qmin]):
190            self.qmin = min(self.x[self.x!=0])
191        elif qmin!=None:                       
192            self.qmin = qmin           
193
194        if qmax !=None:
195            self.qmax = qmax
196           
197        # Determine the range needed in unsmeared-Q to cover
198        # the smeared Q range
199        self._qmin_unsmeared = self.qmin
200        self._qmax_unsmeared = self.qmax   
201       
202        self._first_unsmeared_bin = 0
203        self._last_unsmeared_bin  = len(self.x)-1
204       
205        if self.smearer!=None:
206            self._first_unsmeared_bin, self._last_unsmeared_bin = self.smearer.get_bin_range(self.qmin, self.qmax)
207            self._qmin_unsmeared = self.x[self._first_unsmeared_bin]
208            self._qmax_unsmeared = self.x[self._last_unsmeared_bin]
209           
210        # Identify the bin range for the unsmeared and smeared spaces
211        self.idx = (self.x>=self.qmin) & (self.x <= self.qmax)
212        self.idx_unsmeared = (self.x>=self._qmin_unsmeared) & (self.x <= self._qmax_unsmeared)
213 
214       
215    def getFitRange(self):
216        """
217            @return the range of data.x to fit
218        """
219        return self.qmin, self.qmax
220       
221    def residuals(self, fn):
222        """
223            Compute residuals.
224           
225            If self.smearer has been set, use if to smear
226            the data before computing chi squared.
227           
228            @param fn: function that return model value
229            @return residuals
230        """
231        # Compute theory data f(x)
232        fx= numpy.zeros(len(self.x))
233        fx[self.idx_unsmeared] = fn(self.x[self.idx_unsmeared])
234       
235        ## Smear theory data
236        if self.smearer is not None:
237            fx = self.smearer(fx, self._first_unsmeared_bin, self._last_unsmeared_bin)
238       
239        ## Sanity check
240        if numpy.size(self.dy)!= numpy.size(fx):
241            raise RuntimeError, "FitData1D: invalid error array %d <> %d" % (numpy.shape(self.dy),
242                                                                              numpy.size(fx))
243                                                                             
244        return (self.y[self.idx]-fx[self.idx])/self.dy[self.idx]
245     
246 
247       
248    def residuals_deriv(self, model, pars=[]):
249        """
250            @return residuals derivatives .
251            @note: in this case just return empty array
252        """
253        return []
254   
255   
256class FitData2D(Data2D):
257    """ Wrapper class  for SANS data """
258    def __init__(self,sans_data2d ,data=None, err_data=None):
259        Data2D.__init__(self, data= data, err_data= err_data)
260        """
261            Data can be initital with a data (sans plottable)
262            or with vectors.
263        """
264        self.x_bins_array = []
265        self.y_bins_array = []
266        self.res_err_image=[]
267        self.index_model=[]
268        self.qmin= None
269        self.qmax= None
270        self.set_data(sans_data2d )
271       
272       
273    def set_data(self, sans_data2d, qmin=None, qmax=None ):
274        """
275            Determine the correct x_bin and y_bin to fit
276        """
277        self.x_bins_array= numpy.reshape(sans_data2d.x_bins,
278                                         [1,len(sans_data2d.x_bins)])
279        self.y_bins_array = numpy.reshape(sans_data2d.y_bins,
280                                          [len(sans_data2d.y_bins),1])
281       
282        x_max = max(sans_data2d.xmin, sans_data2d.xmax)
283        y_max = max(sans_data2d.ymin, sans_data2d.ymax)
284       
285        ## fitting range
286        if qmin == None:
287            self.qmin = 1e-16
288        if qmax == None:
289            self.qmax = math.sqrt(x_max*x_max +y_max*y_max)
290        ## new error image for fitting purpose
291        if self.err_data== None or self.err_data ==[]:
292            self.res_err_data= numpy.zeros(len(self.y_bins),len(self.x_bins))
293        else:
294            self.res_err_data = copy.deepcopy(self.err_data)
295        self.res_err_data[self.res_err_data==0]=1
296       
297        self.radius= numpy.sqrt(self.x_bins_array**2 + self.y_bins_array**2)
298        self.index_model = (self.qmin <= self.radius)&(self.radius<= self.qmax)
299       
300       
301    def setFitRange(self,qmin=None,qmax=None):
302        """ to set the fit range"""
303        if qmin==0.0:
304            self.qmin = 1e-16
305        elif qmin!=None:                       
306            self.qmin = qmin           
307        if qmax!=None:
308            self.qmax= qmax
309           
310        self.radius= numpy.sqrt(self.x_bins_array**2 + self.y_bins_array**2)
311        self.index_model = (self.qmin <= self.radius)&(self.radius<= self.qmax)
312       
313    def getFitRange(self):
314        """
315            @return the range of data.x to fit
316        """
317        return self.qmin, self.qmax
318     
319    def residuals(self, fn): 
320       
321        res=self.index_model*(self.data - fn([self.x_bins_array,
322                             self.y_bins_array]))/self.res_err_data
323        return res.ravel() 
324       
325 
326    def residuals_deriv(self, model, pars=[]):
327        """
328            @return residuals derivatives .
329            @note: in this case just return empty array
330        """
331        return []
332   
333class FitAbort(Exception):
334    """
335        Exception raise to stop the fit
336    """
337    print"Creating fit abort Exception"
338
339
340class SansAssembly:
341    """
342         Sans Assembly class a class wrapper to be call in optimizer.leastsq method
343    """
344    def __init__(self,paramlist,model=None , data=None, curr_thread= None):
345        """
346            @param model: the model wrapper for sans -model
347            @param data: the data wrapper for sans data
348        """
349        self.model = model.clone()
350        self.data  = deepcopy(data)
351        self.paramlist = deepcopy(paramlist)
352        self.curr_thread= curr_thread
353        self.res=[]
354        self.func_name="Functor"
355    def chisq(self, params):
356        """
357            Calculates chi^2
358            @param params: list of parameter values
359            @return: chi^2
360        """
361        sum = 0
362        for item in self.res:
363            sum += item*item
364        if len(self.res)==0:
365            return None
366        return sum/ len(self.res)
367   
368    def __call__(self,params):
369        """
370            Compute residuals
371            @param params: value of parameters to fit
372        """
373        #import thread
374        self.model.setParams(self.paramlist,params)
375        self.res= self.data.residuals(self.model.eval)
376        #if self.curr_thread != None :
377        #    try:
378        #        self.curr_thread.isquit()
379        #    except:
380        #        raise FitAbort,"stop leastsqr optimizer"   
381        return self.res
382   
383class FitEngine:
384    def __init__(self):
385        """
386            Base class for scipy and park fit engine
387        """
388        #List of parameter names to fit
389        self.paramList=[]
390        #Dictionnary of fitArrange element (fit problems)
391        self.fitArrangeDict={}
392       
393    def _concatenateData(self, listdata=[]):
394        """ 
395            _concatenateData method concatenates each fields of all data contains ins listdata.
396            @param listdata: list of data
397            @return Data: Data is wrapper class for sans plottable. it is created with all parameters
398             of data concatenanted
399            @raise: if listdata is empty  will return None
400            @raise: if data in listdata don't contain dy field ,will create an error
401            during fitting
402        """
403        #TODO: we have to refactor the way we handle data.
404        # We should move away from plottables and move towards the Data1D objects
405        # defined in DataLoader. Data1D allows data manipulations, which should be
406        # used to concatenate.
407        # In the meantime we should switch off the concatenation.
408        #if len(listdata)>1:
409        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
410        #return listdata[0]
411       
412        if listdata==[]:
413            raise ValueError, " data list missing"
414        else:
415            xtemp=[]
416            ytemp=[]
417            dytemp=[]
418            self.mini=None
419            self.maxi=None
420               
421            for item in listdata:
422                data=item.data
423                mini,maxi=data.getFitRange()
424                if self.mini==None and self.maxi==None:
425                    self.mini=mini
426                    self.maxi=maxi
427                else:
428                    if mini < self.mini:
429                        self.mini=mini
430                    if self.maxi < maxi:
431                        self.maxi=maxi
432                       
433                   
434                for i in range(len(data.x)):
435                    xtemp.append(data.x[i])
436                    ytemp.append(data.y[i])
437                    if data.dy is not None and len(data.dy)==len(data.y):   
438                        dytemp.append(data.dy[i])
439                    else:
440                        raise RuntimeError, "Fit._concatenateData: y-errors missing"
441            data= Data(x=xtemp,y=ytemp,dy=dytemp)
442            data.setFitRange(self.mini, self.maxi)
443            return data
444       
445       
446    def set_model(self,model,Uid,pars=[], constraints=[]):
447        """
448            set a model on a given uid in the fit engine.
449            @param model: sans.models type
450            @param Uid :is the key of the fitArrange dictionnary where model is saved as a value
451            @param pars: the list of parameters to fit
452            @param constraints: list of
453                tuple (name of parameter, value of parameters)
454                the value of parameter must be a string to constraint 2 different
455                parameters.
456                Example:
457                we want to fit 2 model M1 and M2 both have parameters A and B.
458                constraints can be:
459                 constraints = [(M1.A, M2.B+2), (M1.B= M2.A *5),...,]
460            @note : pars must contains only name of existing model's paramaters
461        """
462        if model == None:
463            raise ValueError, "AbstractFitEngine: Need to set model to fit"
464        #print "set_model",model.name
465        #new_model= deepcopy(model)
466        new_model = model.clone()
467        new_model.name = model.name
468        if not issubclass(new_model.__class__, Model):
469            new_model= Model(new_model)
470       
471        if len(constraints)>0:
472            for constraint in constraints:
473                name, value = constraint
474                try:
475                    new_model.parameterset[ str(name)].set( str(value) )
476                except:
477                    msg= "Fit Engine: Error occurs when setting the constraint"
478                    msg += " %s for parameter %s "%(value, name)
479                    raise ValueError, msg
480               
481        if len(pars) >0:
482            temp=[]
483            for item in pars:
484                if item in new_model.model.getParamList():
485                    temp.append(item)
486                    self.paramList.append(item)
487                else:
488                   
489                    msg = "wrong parameter %s used"%str(item)
490                    msg += "to set model %s. Choose"%str(new_model.model.name)
491                    msg += "parameter name within %s"%str(new_model.model.getParamList())
492                    raise ValueError,msg
493             
494            #A fitArrange is already created but contains dList only at Uid
495            if self.fitArrangeDict.has_key(Uid):
496                self.fitArrangeDict[Uid].set_model(new_model)
497                self.fitArrangeDict[Uid].pars= pars
498            else:
499            #no fitArrange object has been create with this Uid
500                fitproblem = FitArrange()
501                fitproblem.set_model(new_model)
502                fitproblem.pars= pars
503                self.fitArrangeDict[Uid] = fitproblem
504               
505        else:
506            raise ValueError, "park_integration:missing parameters"
507   
508    def set_data(self,data,Uid,smearer=None,qmin=None,qmax=None):
509        """ Receives plottable, creates a list of data to fit,set data
510            in a FitArrange object and adds that object in a dictionary
511            with key Uid.
512            @param data: data added
513            @param Uid: unique key corresponding to a fitArrange object with data
514        """
515        #print "set_data",data.__class__.__name__
516        if data.__class__.__name__=='Data2D':
517            fitdata=FitData2D(sans_data2d=data, data=data.data, err_data= data.err_data)
518        else:
519            fitdata=FitData1D(x=data.x, y=data.y , dx= data.dx,dy=data.dy,smearer=smearer)
520       
521        fitdata.setFitRange(qmin=qmin,qmax=qmax)
522        #A fitArrange is already created but contains model only at Uid
523        if self.fitArrangeDict.has_key(Uid):
524            self.fitArrangeDict[Uid].add_data(fitdata)
525        else:
526        #no fitArrange object has been create with this Uid
527            fitproblem= FitArrange()
528            fitproblem.add_data(fitdata)
529            self.fitArrangeDict[Uid]=fitproblem   
530   
531    def get_model(self,Uid):
532        """
533            @param Uid: Uid is key in the dictionary containing the model to return
534            @return  a model at this uid or None if no FitArrange element was created
535            with this Uid
536        """
537        if self.fitArrangeDict.has_key(Uid):
538            return self.fitArrangeDict[Uid].get_model()
539        else:
540            return None
541   
542    def remove_Fit_Problem(self,Uid):
543        """remove   fitarrange in Uid"""
544        if self.fitArrangeDict.has_key(Uid):
545            del self.fitArrangeDict[Uid]
546           
547    def select_problem_for_fit(self,Uid,value):
548        """
549            select a couple of model and data at the Uid position in dictionary
550            and set in self.selected value to value
551            @param value: the value to allow fitting. can only have the value one or zero
552        """
553        if self.fitArrangeDict.has_key(Uid):
554             self.fitArrangeDict[Uid].set_to_fit( value)
555             
556             
557    def get_problem_to_fit(self,Uid):
558        """
559            return the self.selected value of the fit problem of Uid
560           @param Uid: the Uid of the problem
561        """
562        if self.fitArrangeDict.has_key(Uid):
563             self.fitArrangeDict[Uid].get_to_fit()
564   
565class FitArrange:
566    def __init__(self):
567        """
568            Class FitArrange contains a set of data for a given model
569            to perform the Fit.FitArrange must contain exactly one model
570            and at least one data for the fit to be performed.
571            model: the model selected by the user
572            Ldata: a list of data what the user wants to fit
573           
574        """
575        self.model = None
576        self.dList =[]
577        self.pars=[]
578        #self.selected  is zero when this fit problem is not schedule to fit
579        #self.selected is 1 when schedule to fit
580        self.selected = 0
581       
582    def set_model(self,model):
583        """
584            set_model save a copy of the model
585            @param model: the model being set
586        """
587        self.model = model
588        self.model.name = model.name
589       
590    def add_data(self,data):
591        """
592            add_data fill a self.dList with data to fit
593            @param data: Data to add in the list 
594        """
595        if not data in self.dList:
596            self.dList.append(data)
597           
598    def get_model(self):
599        """ @return: saved model """
600        return self.model   
601     
602    def get_data(self):
603        """ @return:  list of data dList"""
604        #return self.dList
605        return self.dList[0] 
606     
607    def remove_data(self,data):
608        """
609            Remove one element from the list
610            @param data: Data to remove from dList
611        """
612        if data in self.dList:
613            self.dList.remove(data)
614    def set_to_fit (self, value=0):
615        """
616           set self.selected to 0 or 1  for other values raise an exception
617           @param value: integer between 0 or 1
618        """
619        self.selected= value
620       
621    def get_to_fit(self):
622        """
623            @return self.selected value
624        """
625        return self.selected
Note: See TracBrowser for help on using the repository browser.