source: sasview/park_integration/AbstractFitEngine.py @ 85bb870

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 85bb870 was 89f3b66, checked in by Gervaise Alina <gervyh@…>, 14 years ago

working pylint

  • Property mode set to 100644
File size: 24.5 KB
Line 
1
2import  copy
3import logging
4import sys
5import numpy
6import math
7import park
8from DataLoader.data_info import Data1D
9from DataLoader.data_info import Data2D
10
11class SansParameter(park.Parameter):
12    """
13    SANS model parameters for use in the PARK fitting service.
14    The parameter attribute value is redirected to the underlying
15    parameter value in the SANS model.
16    """
17    def __init__(self, name, model):
18        """
19        :param name: the name of the model parameter
20        :param model: the sans model to wrap as a park model
21       
22        """
23        self._model, self._name = model, name
24        #set the value for the parameter of the given name
25        self.set(model.getParam(name))
26         
27    def _getvalue(self):
28        """
29        override the _getvalue of park parameter
30       
31        :return value the parameter associates with self.name
32       
33        """
34        return self._model.getParam(self.name)
35   
36    def _setvalue(self, value):
37        """
38        override the _setvalue pf park parameter
39       
40        :param value: the value to set on a given parameter
41       
42        """
43        self._model.setParam(self.name, value)
44       
45    value = property(_getvalue,_setvalue)
46   
47    def _getrange(self):
48        """
49        Override _getrange of park parameter
50        return the range of parameter
51        """
52        #if not  self.name in self._model.getDispParamList():
53        lo, hi = self._model.details[self.name][1:3]
54        if lo is None: lo = -numpy.inf
55        if hi is None: hi = numpy.inf
56        #else:
57            #lo,hi = self._model.details[self.name][1:]
58            #if lo is None: lo = -numpy.inf
59            #if hi is None: hi = numpy.inf
60        if lo >= hi:
61            raise ValueError,"wrong fit range for parameters"
62       
63        return lo, hi
64   
65    def _setrange(self, r):
66        """
67        override _setrange of park parameter
68       
69        :param r: the value of the range to set
70       
71        """
72        self._model.details[self.name][1:3] = r
73    range = property(_getrange, _setrange)
74   
75class Model(park.Model):
76    """
77    PARK wrapper for SANS models.
78    """
79    def __init__(self, sans_model, **kw):
80        """
81        :param sans_model: the sans model to wrap using park interface
82       
83        """
84        park.Model.__init__(self, **kw)
85        self.model = sans_model
86        self.name = sans_model.name
87        #list of parameters names
88        self.sansp = sans_model.getParamList()
89        #list of park parameter
90        self.parkp = [SansParameter(p,sans_model) for p in self.sansp]
91        #list of parameterset
92        self.parameterset = park.ParameterSet(sans_model.name, pars=self.parkp)
93        self.pars = []
94 
95    def getParams(self, fitparams):
96        """
97        return a list of value of paramter to fit
98       
99        :param fitparams: list of paramaters name to fit
100       
101        """
102        list = []
103        self.pars = []
104        self.pars = fitparams
105        for item in fitparams:
106            for element in self.parkp:
107                 if element.name == str(item):
108                     list.append(element.value)
109        return list
110   
111    def setParams(self, paramlist, params):
112        """
113        Set value for parameters to fit
114       
115        :param params: list of value for parameters to fit
116       
117        """
118        try:
119            for i in range(len(self.parkp)):
120                for j in range(len(paramlist)):
121                    if self.parkp[i].name == paramlist[j]:
122                        self.parkp[i].value = params[j]
123                        self.model.setParam(self.parkp[i].name, params[j])
124        except:
125            raise
126 
127    def eval(self, x):
128        """
129        override eval method of park model.
130       
131        :param x: the x value used to compute a function
132       
133        """
134        try:
135            return self.model.evalDistribution(x)
136        except:
137            raise
138
139   
140class FitData1D(Data1D):
141    """
142    Wrapper class  for SANS data
143    FitData1D inherits from DataLoader.data_info.Data1D. Implements
144    a way to get residuals from data.
145    """
146    def __init__(self, x, y, dx=None, dy=None, smearer=None):
147        """
148        :param smearer: is an object of class QSmearer or SlitSmearer
149           that will smear the theory data (slit smearing or resolution
150           smearing) when set.
151       
152        The proper way to set the smearing object would be to
153        do the following: ::
154       
155            from DataLoader.qsmearing import smear_selection
156            smearer = smear_selection(some_data)
157            fitdata1d = FitData1D( x= [1,3,..,],
158                                    y= [3,4,..,8],
159                                    dx=None,
160                                    dy=[1,2...], smearer= smearer)
161       
162        :Note: that some_data _HAS_ to be of class DataLoader.data_info.Data1D
163            Setting it back to None will turn smearing off.
164           
165        """
166        Data1D.__init__(self, x=x, y=y, dx=dx, dy=dy)
167       
168        self.smearer = smearer
169       
170        # Check error bar; if no error bar found, set it constant(=1)
171        # TODO: Should provide an option for users to set it like percent, constant, or dy data
172        if dy == None or dy == [] or dy.all() == 0:
173            self.dy = numpy.ones(len(y)) 
174        else:
175            self.dy = numpy.asarray(dy).copy()
176
177        ## Min Q-value
178        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
179        if min (self.x) == 0.0 and self.x[0] == 0 and\
180                     not numpy.isfinite(self.y[0]):
181            self.qmin = min(self.x[self.x!=0])
182        else:                             
183            self.qmin = min(self.x)
184        ## Max Q-value
185        self.qmax = max(self.x)
186       
187        # Range used for input to smearing
188        self._qmin_unsmeared = self.qmin
189        self._qmax_unsmeared = self.qmax
190        # Identify the bin range for the unsmeared and smeared spaces
191        self.idx = (self.x >= self.qmin) & (self.x <= self.qmax)
192        self.idx_unsmeared = (self.x >= self._qmin_unsmeared) \
193                            & (self.x <= self._qmax_unsmeared)
194 
195    def setFitRange(self, qmin=None, qmax=None):
196        """ to set the fit range"""
197        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
198        # ToDo: Find better way to do it.
199        if qmin == 0.0 and not numpy.isfinite(self.y[qmin]):
200            self.qmin = min(self.x[self.x != 0])
201        elif qmin != None:                       
202            self.qmin = qmin           
203        if qmax != None:
204            self.qmax = qmax
205        # Determine the range needed in unsmeared-Q to cover
206        # the smeared Q range
207        self._qmin_unsmeared = self.qmin
208        self._qmax_unsmeared = self.qmax   
209       
210        self._first_unsmeared_bin = 0
211        self._last_unsmeared_bin  = len(self.x) - 1
212       
213        if self.smearer!=None:
214            self._first_unsmeared_bin, self._last_unsmeared_bin = \
215                    self.smearer.get_bin_range(self.qmin, self.qmax)
216            self._qmin_unsmeared = self.x[self._first_unsmeared_bin]
217            self._qmax_unsmeared = self.x[self._last_unsmeared_bin]
218           
219        # Identify the bin range for the unsmeared and smeared spaces
220        self.idx = (self.x >= self.qmin) & (self.x <= self.qmax)
221        ## zero error can not participate for fitting
222        self.idx = self.idx & (self.dy != 0) 
223        self.idx_unsmeared = (self.x >= self._qmin_unsmeared) \
224                            & (self.x <= self._qmax_unsmeared)
225 
226    def getFitRange(self):
227        """
228        return the range of data.x to fit
229        """
230        return self.qmin, self.qmax
231       
232    def residuals(self, fn):
233        """
234        Compute residuals.
235       
236        If self.smearer has been set, use if to smear
237        the data before computing chi squared.
238       
239        :param fn: function that return model value
240       
241        :return: residuals
242       
243        """
244        # Compute theory data f(x)
245        fx = numpy.zeros(len(self.x))
246        fx[self.idx_unsmeared] = fn(self.x[self.idx_unsmeared])
247       
248        ## Smear theory data
249        if self.smearer is not None:
250            fx = self.smearer(fx, self._first_unsmeared_bin, 
251                              self._last_unsmeared_bin)
252        ## Sanity check
253        if numpy.size(self.dy) != numpy.size(fx):
254            msg = "FitData1D: invalid error array "
255            msg += "%d <> %d" % (numpy.shape(self.dy), numpy.size(fx))                                                     
256            raise RuntimeError, msg 
257        return (self.y[self.idx] - fx[self.idx]) / self.dy[self.idx]
258     
259    def residuals_deriv(self, model, pars=[]):
260        """
261        :return: residuals derivatives .
262       
263        :note: in this case just return empty array
264       
265        """
266        return []
267   
268class FitData2D(Data2D):
269    """ Wrapper class  for SANS data """
270    def __init__(self, sans_data2d, data=None, err_data=None):
271        Data2D.__init__(self, data=data, err_data= err_data)
272        """
273        Data can be initital with a data (sans plottable)
274        or with vectors.
275        """
276        self.res_err_image = []
277        self.index_model = []
278        self.qmin = None
279        self.qmax = None
280        self.smearer = None
281        self.set_data(sans_data2d)
282
283       
284    def set_data(self, sans_data2d, qmin=None, qmax=None):
285        """
286        Determine the correct qx_data and qy_data within range to fit
287        """
288        self.data = sans_data2d.data
289        self.err_data = sans_data2d.err_data
290        self.qx_data = sans_data2d.qx_data
291        self.qy_data = sans_data2d.qy_data
292        self.mask = sans_data2d.mask
293
294        x_max = max(math.fabs(sans_data2d.xmin), math.fabs(sans_data2d.xmax))
295        y_max = max(math.fabs(sans_data2d.ymin), math.fabs(sans_data2d.ymax))
296       
297        ## fitting range
298        if qmin == None:
299            self.qmin = 1e-16
300        if qmax == None:
301            self.qmax = math.sqrt(x_max * x_max + y_max * y_max)
302        ## new error image for fitting purpose
303        if self.err_data == None or self.err_data == []:
304            self.res_err_data = numpy.ones(len(self.data))
305        else:
306            self.res_err_data = copy.deepcopy(self.err_data)
307        #self.res_err_data[self.res_err_data==0]=1
308       
309        self.radius = numpy.sqrt(self.qx_data**2 + self.qy_data**2)
310       
311        # Note: mask = True: for MASK while mask = False for NOT to mask
312        self.index_model = ((self.qmin <= self.radius)&\
313                            (self.radius <= self.qmax))
314        self.index_model = (self.index_model) & (self.mask)
315        self.index_model = (self.index_model) & (numpy.isfinite(self.data))
316       
317    def set_smearer(self,smearer): 
318        """
319        Set smearer
320        """
321        if smearer == None:
322            return
323        self.smearer = smearer
324        self.smearer.set_index(self.index_model)
325        self.smearer.get_data()
326
327    def setFitRange(self, qmin=None, qmax=None):
328        """ to set the fit range"""
329        if qmin == 0.0:
330            self.qmin = 1e-16
331        elif qmin != None:                       
332            self.qmin = qmin           
333        if qmax != None:
334            self.qmax = qmax       
335        self.radius = numpy.sqrt(self.qx_data**2 + self.qy_data**2)
336        self.index_model = ((self.qmin <= self.radius)&\
337                            (self.radius <= self.qmax))
338        self.index_model = (self.index_model) &(self.mask)
339        self.index_model = (self.index_model) & (numpy.isfinite(self.data))
340        self.index_model = (self.index_model) & (self.res_err_data!=0)
341       
342    def getFitRange(self):
343        """
344        return the range of data.x to fit
345        """
346        return self.qmin, self.qmax
347     
348    def residuals(self, fn): 
349        """
350        return the residuals
351        """ 
352        if self.smearer != None:
353            fn.set_index(self.index_model)
354            # Get necessary data from self.data and set the data for smearing
355            fn.get_data()
356
357            gn = fn.get_value() 
358        else:
359            gn = fn([self.qx_data[self.index_model],
360                     self.qy_data[self.index_model]])
361        # use only the data point within ROI range
362        res = (self.data[self.index_model] - gn)/\
363                    self.res_err_data[self.index_model]
364        return res
365       
366    def residuals_deriv(self, model, pars=[]):
367        """
368        :return: residuals derivatives .
369       
370        :note: in this case just return empty array
371       
372        """
373        return []
374   
375class FitAbort(Exception):
376    """
377    Exception raise to stop the fit
378    """
379    #print"Creating fit abort Exception"
380
381
382class SansAssembly:
383    """
384    Sans Assembly class a class wrapper to be call in optimizer.leastsq method
385    """
386    def __init__(self, paramlist, model=None , data=None, fitresult=None,
387                 handler=None, curr_thread=None):
388        """
389        :param Model: the model wrapper fro sans -model
390        :param Data: the data wrapper for sans data
391       
392        """
393        self.model = model
394        self.data  = data
395        self.paramlist = paramlist
396        self.curr_thread = curr_thread
397        self.handler = handler
398        self.fitresult = fitresult
399        self.res = []
400        self.func_name = "Functor"
401       
402    def chisq(self, params):
403        """
404        Calculates chi^2
405       
406        :param params: list of parameter values
407       
408        :return: chi^2
409       
410        """
411        sum = 0
412        for item in self.res:
413            sum += item*item
414        if len(self.res)==0:
415            return None
416        return sum/ len(self.res)
417   
418    def __call__(self,params):
419        """
420        Compute residuals
421       
422        :param params: value of parameters to fit
423       
424        """
425        #import thread
426        self.model.setParams(self.paramlist,params)
427        self.res = self.data.residuals(self.model.eval)
428        if self.fitresult is not None and  self.handler is not None:
429            self.fitresult.set_model(model=self.model)
430            fitness = self.chisq(params=params)
431            self.fitresult.set_fitness(fitness=fitness)
432            self.handler.set_result(result=self.fitresult)
433            self.handler.update_fit()
434       
435        #if self.curr_thread != None :
436        #    try:
437        #        self.curr_thread.isquit()
438        #    except:
439        #        raise FitAbort,"stop leastsqr optimizer"   
440        return self.res
441   
442class FitEngine:
443    def __init__(self):
444        """
445        Base class for scipy and park fit engine
446        """
447        #List of parameter names to fit
448        self.paramList=[]
449        #Dictionnary of fitArrange element (fit problems)
450        self.fitArrangeDict={}
451       
452    def _concatenateData(self, listdata=[]):
453        """ 
454        _concatenateData method concatenates each fields of all data
455        contains ins listdata.
456       
457        :param listdata: list of data
458       
459        :return Data: Data is wrapper class for sans plottable.
460         it is created with all parameters of data concatenanted
461           
462        :raise: if listdata is empty  will return None
463        :raise: if data in listdata don't contain dy field ,
464        will create an error during fitting
465           
466        """
467        #TODO: we have to refactor the way we handle data.
468        # We should move away from plottables and move towards the Data1D objects
469        # defined in DataLoader. Data1D allows data manipulations, which should be
470        # used to concatenate.
471        # In the meantime we should switch off the concatenation.
472        #if len(listdata)>1:
473        #    raise RuntimeError, "FitEngine._concatenateData: Multiple data files is not currently supported"
474        #return listdata[0]
475       
476        if listdata == []:
477            raise ValueError, " data list missing"
478        else:
479            xtemp = []
480            ytemp = []
481            dytemp = []
482            self.mini = None
483            self.maxi = None
484               
485            for item in listdata:
486                data = item.data
487                mini, maxi = data.getFitRange()
488                if self.mini == None and self.maxi == None:
489                    self.mini = mini
490                    self.maxi = maxi
491                else:
492                    if mini < self.mini:
493                        self.mini = mini
494                    if self.maxi < maxi:
495                        self.maxi = maxi
496                for i in range(len(data.x)):
497                    xtemp.append(data.x[i])
498                    ytemp.append(data.y[i])
499                    if data.dy is not None and len(data.dy) == len(data.y):   
500                        dytemp.append(data.dy[i])
501                    else:
502                        msg = "Fit._concatenateData: y-errors missing"
503                        raise RuntimeError, msg
504            data = Data(x=xtemp, y=ytemp, dy=dytemp)
505            data.setFitRange(self.mini, self.maxi)
506            return data
507       
508       
509    def set_model(self, model, Uid, pars=[], constraints=[]):
510        """
511        set a model on a given uid in the fit engine.
512       
513        :param model: sans.models type
514        :param Uid: is the key of the fitArrange dictionary where model is
515                saved as a value
516        :param pars: the list of parameters to fit
517        :param constraints: list of
518            tuple (name of parameter, value of parameters)
519            the value of parameter must be a string to constraint 2 different
520            parameters.
521            Example: 
522            we want to fit 2 model M1 and M2 both have parameters A and B.
523            constraints can be:
524             constraints = [(M1.A, M2.B+2), (M1.B= M2.A *5),...,]
525           
526             
527        :note: pars must contains only name of existing model's parameters
528       
529        """
530        if model == None:
531            raise ValueError, "AbstractFitEngine: Need to set model to fit"
532       
533        new_model = model
534        if not issubclass(model.__class__, Model):
535            new_model = Model(model)
536       
537        if len(constraints) > 0:
538            for constraint in constraints:
539                name, value = constraint
540                try:
541                    new_model.parameterset[str(name)].set(str(value))
542                except:
543                    msg = "Fit Engine: Error occurs when setting the constraint"
544                    msg += " %s for parameter %s "%(value, name)
545                    raise ValueError, msg
546               
547        if len(pars) > 0:
548            temp = []
549            for item in pars:
550                if item in new_model.model.getParamList():
551                    temp.append(item)
552                    self.paramList.append(item)
553                else:
554                   
555                    msg = "wrong parameter %s used" % str(item)
556                    msg += "to set model %s. Choose" % str(new_model.model.name)
557                    msg += "parameter name within %s" % \
558                                str(new_model.model.getParamList())
559                    raise ValueError, msg
560             
561            #A fitArrange is already created but contains dList only at Uid
562            if self.fitArrangeDict.has_key(Uid):
563                self.fitArrangeDict[Uid].set_model(new_model)
564                self.fitArrangeDict[Uid].pars = pars
565            else:
566            #no fitArrange object has been create with this Uid
567                fitproblem = FitArrange()
568                fitproblem.set_model(new_model)
569                fitproblem.pars = pars
570                self.fitArrangeDict[Uid] = fitproblem
571               
572        else:
573            raise ValueError, "park_integration:missing parameters"
574   
575    def set_data(self, data, Uid, smearer=None, qmin=None, qmax=None):
576        """
577        Receives plottable, creates a list of data to fit,set data
578        in a FitArrange object and adds that object in a dictionary
579        with key Uid.
580       
581        :param data: data added
582        :param Uid: unique key corresponding to a fitArrange object with data
583       
584        """
585        if data.__class__.__name__ == 'Data2D':
586            fitdata = FitData2D(sans_data2d=data, data=data.data,
587                                 err_data=data.err_data)
588        else:
589            fitdata = FitData1D(x=data.x, y=data.y ,
590                                 dx=data.dx, dy=data.dy, smearer=smearer)
591       
592        fitdata.setFitRange(qmin=qmin, qmax=qmax)
593        #A fitArrange is already created but contains model only at Uid
594        if self.fitArrangeDict.has_key(Uid):
595            self.fitArrangeDict[Uid].add_data(fitdata)
596        else:
597        #no fitArrange object has been create with this Uid
598            fitproblem = FitArrange()
599            fitproblem.add_data(fitdata)
600            self.fitArrangeDict[Uid] = fitproblem   
601   
602    def get_model(self, Uid):
603        """
604       
605        :param Uid: Uid is key in the dictionary containing the model to return
606       
607        :return:  a model at this uid or None if no FitArrange element was created
608            with this Uid
609           
610        """
611        if self.fitArrangeDict.has_key(Uid):
612            return self.fitArrangeDict[Uid].get_model()
613        else:
614            return None
615   
616    def remove_Fit_Problem(self, Uid):
617        """remove   fitarrange in Uid"""
618        if self.fitArrangeDict.has_key(Uid):
619            del self.fitArrangeDict[Uid]
620           
621    def select_problem_for_fit(self, Uid, value):
622        """
623        select a couple of model and data at the Uid position in dictionary
624        and set in self.selected value to value
625       
626        :param value: the value to allow fitting.
627                can only have the value one or zero
628               
629        """
630        if self.fitArrangeDict.has_key(Uid):
631             self.fitArrangeDict[Uid].set_to_fit(value)
632             
633    def get_problem_to_fit(self, Uid):
634        """
635        return the self.selected value of the fit problem of Uid
636       
637        :param Uid: the Uid of the problem
638       
639        """
640        if self.fitArrangeDict.has_key(Uid):
641             self.fitArrangeDict[Uid].get_to_fit()
642   
643class FitArrange:
644    def __init__(self):
645        """
646        Class FitArrange contains a set of data for a given model
647        to perform the Fit.FitArrange must contain exactly one model
648        and at least one data for the fit to be performed.
649       
650        model: the model selected by the user
651        Ldata: a list of data what the user wants to fit
652           
653        """
654        self.model = None
655        self.dList = []
656        self.pars = []
657        #self.selected  is zero when this fit problem is not schedule to fit
658        #self.selected is 1 when schedule to fit
659        self.selected = 0
660       
661    def set_model(self, model):
662        """
663        set_model save a copy of the model
664       
665        :param model: the model being set
666       
667        """
668        self.model = model
669       
670    def add_data(self, data):
671        """
672        add_data fill a self.dList with data to fit
673       
674        :param data: Data to add in the list 
675       
676        """
677        if not data in self.dList:
678            self.dList.append(data)
679           
680    def get_model(self):
681        """
682       
683        :return: saved model
684       
685        """
686        return self.model   
687     
688    def get_data(self):
689        """
690       
691        :return: list of data dList
692       
693        """
694        #return self.dList
695        return self.dList[0] 
696     
697    def remove_data(self, data):
698        """
699        Remove one element from the list
700       
701        :param data: Data to remove from dList
702       
703        """
704        if data in self.dList:
705            self.dList.remove(data)
706           
707    def set_to_fit (self, value=0):
708        """
709        set self.selected to 0 or 1  for other values raise an exception
710       
711        :param value: integer between 0 or 1
712       
713        """
714        self.selected = value
715       
716    def get_to_fit(self):
717        """
718        return self.selected value
719        """
720        return self.selected
Note: See TracBrowser for help on using the repository browser.