source: sasview/park_integration/src/sans/fit/AbstractFitEngine.py @ 706667b

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 706667b was 1cff677, checked in by Gervaise Alina <gervyh@…>, 13 years ago

working on get thing data and model from result of fit

  • Property mode set to 100644
File size: 25.7 KB
Line 
1
2import  copy
3#import logging
4#import sys
5import numpy
6import math
7import park
8from sans.dataloader.data_info import Data1D
9from sans.dataloader.data_info import Data2D
10import time
11_SMALLVALUE = 1.0e-10   
12   
13class SansParameter(park.Parameter):
14    """
15    SANS model parameters for use in the PARK fitting service.
16    The parameter attribute value is redirected to the underlying
17    parameter value in the SANS model.
18    """
19    def __init__(self, name, model, data):
20        """
21        :param name: the name of the model parameter
22        :param model: the sans model to wrap as a park model
23       
24        """
25        park.Parameter.__init__(self, name)
26        self._model, self._name = model, name
27        self.data = data
28        self.model = model
29        #set the value for the parameter of the given name
30        self.set(model.getParam(name))
31         
32    def _getvalue(self):
33        """
34        override the _getvalue of park parameter
35       
36        :return value the parameter associates with self.name
37       
38        """
39        return self._model.getParam(self.name)
40   
41    def _setvalue(self, value):
42        """
43        override the _setvalue pf park parameter
44       
45        :param value: the value to set on a given parameter
46       
47        """
48        self._model.setParam(self.name, value)
49       
50    value = property(_getvalue, _setvalue)
51   
52    def _getrange(self):
53        """
54        Override _getrange of park parameter
55        return the range of parameter
56        """
57        #if not  self.name in self._model.getDispParamList():
58        lo, hi = self._model.details[self.name][1:3]
59        if lo is None: lo = -numpy.inf
60        if hi is None: hi = numpy.inf
61        #else:
62            #lo,hi = self._model.details[self.name][1:]
63            #if lo is None: lo = -numpy.inf
64            #if hi is None: hi = numpy.inf
65        if lo > hi:
66            raise ValueError,"wrong fit range for parameters"
67       
68        return lo, hi
69   
70    def get_name(self):
71        """
72        """
73        return self._getname()
74   
75    def _setrange(self, r):
76        """
77        override _setrange of park parameter
78       
79        :param r: the value of the range to set
80       
81        """
82        self._model.details[self.name][1:3] = r
83    range = property(_getrange, _setrange)
84   
85class Model(park.Model):
86    """
87    PARK wrapper for SANS models.
88    """
89    def __init__(self, sans_model, sans_data=None, **kw):
90        """
91        :param sans_model: the sans model to wrap using park interface
92       
93        """
94        park.Model.__init__(self, **kw)
95        self.model = sans_model
96        self.name = sans_model.name
97        self.data = sans_data
98        #list of parameters names
99        self.sansp = sans_model.getParamList()
100        #list of park parameter
101        self.parkp = [SansParameter(p, sans_model, sans_data) for p in self.sansp]
102        #list of parameterset
103        self.parameterset = park.ParameterSet(sans_model.name, pars=self.parkp)
104        self.pars = []
105 
106    def get_params(self, fitparams):
107        """
108        return a list of value of paramter to fit
109       
110        :param fitparams: list of paramaters name to fit
111       
112        """
113        list_params = []
114        self.pars = []
115        self.pars = fitparams
116        for item in fitparams:
117            for element in self.parkp:
118                if element.name == str(item):
119                    list_params.append(element.value)
120        return list_params
121   
122    def set_params(self, paramlist, params):
123        """
124        Set value for parameters to fit
125       
126        :param params: list of value for parameters to fit
127       
128        """
129        try:
130            for i in range(len(self.parkp)):
131                for j in range(len(paramlist)):
132                    if self.parkp[i].name == paramlist[j]:
133                        self.parkp[i].value = params[j]
134                        self.model.setParam(self.parkp[i].name, params[j])
135        except:
136            raise
137 
138    def eval(self, x):
139        """
140        override eval method of park model.
141       
142        :param x: the x value used to compute a function
143       
144        """
145        try:
146            return self.model.evalDistribution(x)
147        except:
148            raise
149       
150    def eval_derivs(self, x, pars=[]):
151        """
152        Evaluate the model and derivatives wrt pars at x.
153
154        pars is a list of the names of the parameters for which derivatives
155        are desired.
156
157        This method needs to be specialized in the model to evaluate the
158        model function.  Alternatively, the model can implement is own
159        version of residuals which calculates the residuals directly
160        instead of calling eval.
161        """
162        return []
163
164
165   
166class FitData1D(Data1D):
167    """
168    Wrapper class  for SANS data
169    FitData1D inherits from DataLoader.data_info.Data1D. Implements
170    a way to get residuals from data.
171    """
172    def __init__(self, x, y, dx=None, dy=None, smearer=None):
173        """
174        :param smearer: is an object of class QSmearer or SlitSmearer
175           that will smear the theory data (slit smearing or resolution
176           smearing) when set.
177       
178        The proper way to set the smearing object would be to
179        do the following: ::
180       
181            from DataLoader.qsmearing import smear_selection
182            smearer = smear_selection(some_data)
183            fitdata1d = FitData1D( x= [1,3,..,],
184                                    y= [3,4,..,8],
185                                    dx=None,
186                                    dy=[1,2...], smearer= smearer)
187       
188        :Note: that some_data _HAS_ to be of class DataLoader.data_info.Data1D
189            Setting it back to None will turn smearing off.
190           
191        """
192        Data1D.__init__(self, x=x, y=y, dx=dx, dy=dy)
193       
194        self.smearer = smearer
195        self._first_unsmeared_bin = None
196        self._last_unsmeared_bin = None
197        # Check error bar; if no error bar found, set it constant(=1)
198        # TODO: Should provide an option for users to set it like percent,
199        # constant, or dy data
200        if dy == None or dy == [] or dy.all() == 0:
201            self.dy = numpy.ones(len(y)) 
202        else:
203            self.dy = numpy.asarray(dy).copy()
204
205        ## Min Q-value
206        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
207        if min (self.x) == 0.0 and self.x[0] == 0 and\
208                     not numpy.isfinite(self.y[0]):
209            self.qmin = min(self.x[self.x!=0])
210        else:                             
211            self.qmin = min(self.x)
212        ## Max Q-value
213        self.qmax = max(self.x)
214       
215        # Range used for input to smearing
216        self._qmin_unsmeared = self.qmin
217        self._qmax_unsmeared = self.qmax
218        # Identify the bin range for the unsmeared and smeared spaces
219        self.idx = (self.x >= self.qmin) & (self.x <= self.qmax)
220        self.idx_unsmeared = (self.x >= self._qmin_unsmeared) \
221                            & (self.x <= self._qmax_unsmeared)
222 
223    def set_fit_range(self, qmin=None, qmax=None):
224        """ to set the fit range"""
225        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
226        # ToDo: Find better way to do it.
227        if qmin == 0.0 and not numpy.isfinite(self.y[qmin]):
228            self.qmin = min(self.x[self.x != 0])
229        elif qmin != None:                       
230            self.qmin = qmin           
231        if qmax != None:
232            self.qmax = qmax
233        # Determine the range needed in unsmeared-Q to cover
234        # the smeared Q range
235        self._qmin_unsmeared = self.qmin
236        self._qmax_unsmeared = self.qmax   
237       
238        self._first_unsmeared_bin = 0
239        self._last_unsmeared_bin  = len(self.x) - 1
240       
241        if self.smearer != None:
242            self._first_unsmeared_bin, self._last_unsmeared_bin = \
243                    self.smearer.get_bin_range(self.qmin, self.qmax)
244            self._qmin_unsmeared = self.x[self._first_unsmeared_bin]
245            self._qmax_unsmeared = self.x[self._last_unsmeared_bin]
246           
247        # Identify the bin range for the unsmeared and smeared spaces
248        self.idx = (self.x >= self.qmin) & (self.x <= self.qmax)
249        ## zero error can not participate for fitting
250        self.idx = self.idx & (self.dy != 0) 
251        self.idx_unsmeared = (self.x >= self._qmin_unsmeared) \
252                            & (self.x <= self._qmax_unsmeared)
253 
254    def get_fit_range(self):
255        """
256        return the range of data.x to fit
257        """
258        return self.qmin, self.qmax
259       
260    def residuals(self, fn):
261        """
262        Compute residuals.
263       
264        If self.smearer has been set, use if to smear
265        the data before computing chi squared.
266       
267        :param fn: function that return model value
268       
269        :return: residuals
270       
271        """
272        # Compute theory data f(x)
273        fx = numpy.zeros(len(self.x))
274        fx[self.idx_unsmeared] = fn(self.x[self.idx_unsmeared])
275       
276        ## Smear theory data
277        if self.smearer is not None:
278            fx = self.smearer(fx, self._first_unsmeared_bin, 
279                              self._last_unsmeared_bin)
280        ## Sanity check
281        if numpy.size(self.dy) != numpy.size(fx):
282            msg = "FitData1D: invalid error array "
283            msg += "%d <> %d" % (numpy.shape(self.dy), numpy.size(fx))                                                     
284            raise RuntimeError, msg 
285        return (self.y[self.idx] - fx[self.idx]) / self.dy[self.idx]
286     
287    def residuals_deriv(self, model, pars=[]):
288        """
289        :return: residuals derivatives .
290       
291        :note: in this case just return empty array
292       
293        """
294        return []
295   
296class FitData2D(Data2D):
297    """ Wrapper class  for SANS data """
298    def __init__(self, sans_data2d, data=None, err_data=None):
299        Data2D.__init__(self, data=data, err_data=err_data)
300        """
301        Data can be initital with a data (sans plottable)
302        or with vectors.
303        """
304        self.res_err_image = []
305        self.index_model = []
306        self.qmin = None
307        self.qmax = None
308        self.smearer = None
309        self.radius = 0
310        self.res_err_data = []
311        self.set_data(sans_data2d)
312
313    def set_data(self, sans_data2d, qmin=None, qmax=None):
314        """
315        Determine the correct qx_data and qy_data within range to fit
316        """
317        self.data = sans_data2d.data
318        self.err_data = sans_data2d.err_data
319        self.qx_data = sans_data2d.qx_data
320        self.qy_data = sans_data2d.qy_data
321        self.mask = sans_data2d.mask
322
323        x_max = max(math.fabs(sans_data2d.xmin), math.fabs(sans_data2d.xmax))
324        y_max = max(math.fabs(sans_data2d.ymin), math.fabs(sans_data2d.ymax))
325       
326        ## fitting range
327        if qmin == None:
328            self.qmin = 1e-16
329        if qmax == None:
330            self.qmax = math.sqrt(x_max * x_max + y_max * y_max)
331        ## new error image for fitting purpose
332        if self.err_data == None or self.err_data == []:
333            self.res_err_data = numpy.ones(len(self.data))
334        else:
335            self.res_err_data = copy.deepcopy(self.err_data)
336        #self.res_err_data[self.res_err_data==0]=1
337       
338        self.radius = numpy.sqrt(self.qx_data**2 + self.qy_data**2)
339       
340        # Note: mask = True: for MASK while mask = False for NOT to mask
341        self.index_model = ((self.qmin <= self.radius)&\
342                            (self.radius <= self.qmax))
343        self.index_model = (self.index_model) & (self.mask)
344        self.index_model = (self.index_model) & (numpy.isfinite(self.data))
345       
346    def set_smearer(self, smearer): 
347        """
348        Set smearer
349        """
350        if smearer == None:
351            return
352        self.smearer = smearer
353        self.smearer.set_index(self.index_model)
354        self.smearer.get_data()
355
356    def set_fit_range(self, qmin=None, qmax=None):
357        """ to set the fit range"""
358        if qmin == 0.0:
359            self.qmin = 1e-16
360        elif qmin != None:                       
361            self.qmin = qmin           
362        if qmax != None:
363            self.qmax = qmax       
364        self.radius = numpy.sqrt(self.qx_data**2 + self.qy_data**2)
365        self.index_model = ((self.qmin <= self.radius)&\
366                            (self.radius <= self.qmax))
367        self.index_model = (self.index_model) &(self.mask)
368        self.index_model = (self.index_model) & (numpy.isfinite(self.data))
369        self.index_model = (self.index_model) & (self.res_err_data != 0)
370       
371    def get_fit_range(self):
372        """
373        return the range of data.x to fit
374        """
375        return self.qmin, self.qmax
376     
377    def residuals(self, fn): 
378        """
379        return the residuals
380        """ 
381        if self.smearer != None:
382            fn.set_index(self.index_model)
383            # Get necessary data from self.data and set the data for smearing
384            fn.get_data()
385
386            gn = fn.get_value() 
387        else:
388            gn = fn([self.qx_data[self.index_model],
389                     self.qy_data[self.index_model]])
390        # use only the data point within ROI range
391        res = (self.data[self.index_model] - gn)/\
392                    self.res_err_data[self.index_model]
393        return res
394       
395    def residuals_deriv(self, model, pars=[]):
396        """
397        :return: residuals derivatives .
398       
399        :note: in this case just return empty array
400       
401        """
402        return []
403   
404class FitAbort(Exception):
405    """
406    Exception raise to stop the fit
407    """
408    #pass
409    #print"Creating fit abort Exception"
410
411
412class SansAssembly:
413    """
414    Sans Assembly class a class wrapper to be call in optimizer.leastsq method
415    """
416    def __init__(self, paramlist, model=None , data=None, fitresult=None,
417                 handler=None, curr_thread=None):
418        """
419        :param Model: the model wrapper fro sans -model
420        :param Data: the data wrapper for sans data
421       
422        """
423        self.model = model
424        self.data  = data
425        self.paramlist = paramlist
426        self.curr_thread = curr_thread
427        self.handler = handler
428        self.fitresult = fitresult
429        self.res = []
430        self.true_res = []
431        self.func_name = "Functor"
432       
433    #def chisq(self, params):
434    def chisq(self):
435        """
436        Calculates chi^2
437       
438        :param params: list of parameter values
439       
440        :return: chi^2
441       
442        """
443        sum = 0
444        for item in self.true_res:
445            sum += item * item
446        if len(self.true_res) == 0:
447            return None
448        return sum / len(self.true_res)
449   
450    def __call__(self, params):
451        """
452        Compute residuals
453       
454        :param params: value of parameters to fit
455       
456        """ 
457        #import thread
458        self.model.set_params(self.paramlist, params)
459        self.true_res = self.data.residuals(self.model.eval)
460        # check parameters range
461        if self.check_param_range():
462            # if the param value is outside of the bound
463            # just silent return res = inf
464            return self.res
465        self.res = self.true_res       
466        if self.fitresult is not None and  self.handler is not None:
467            self.fitresult.set_model(model=self.model)
468            #fitness = self.chisq(params=params)
469            fitness = self.chisq()
470            self.fitresult.pvec = params
471            self.fitresult.set_fitness(fitness=fitness)
472            self.handler.set_result(result=self.fitresult)
473            self.handler.update_fit()
474
475            if self.curr_thread != None :
476                try:
477                    self.curr_thread.isquit()
478                except:
479                    msg = "Fitting: Terminated...       Note: Forcing to stop " 
480                    msg += "fitting may cause a 'Functor error message' "
481                    msg += "being recorded in the log file....."
482                    self.handler.error(msg)
483                    raise
484                    #return
485         
486        return self.res
487   
488    def check_param_range(self):
489        """
490        Check the lower and upper bound of the parameter value
491        and set res to the inf if the value is outside of the
492        range
493        :limitation: the initial values must be within range.
494        """
495
496        #time.sleep(0.01)
497        is_outofbound = False
498        # loop through the fit parameters
499        for p in self.model.parameterset:
500            param_name = p.get_name()
501            if param_name in self.paramlist:
502               
503                # if the range was defined, check the range
504                if numpy.isfinite(p.range[0]):
505                    if p.value == 0:
506                        # This value works on Scipy
507                        # Do not change numbers below
508                        value = _SMALLVALUE
509                    else:
510                        value = p.value
511                    # For leastsq, it needs a bit step back from the boundary
512                    val = p.range[0] - value * _SMALLVALUE
513                    if p.value < val: 
514                        self.res *= 1e+6
515                       
516                        is_outofbound = True
517                        break
518                if numpy.isfinite(p.range[1]):
519                    # This value works on Scipy
520                    # Do not change numbers below
521                    if p.value == 0:
522                        value = _SMALLVALUE
523                    else:
524                        value = p.value
525                    # For leastsq, it needs a bit step back from the boundary
526                    val = p.range[1] + value * _SMALLVALUE
527                    if p.value > val:
528                        self.res *= 1e+6
529                        is_outofbound = True
530                        break
531
532        return is_outofbound
533   
534   
535class FitEngine:
536    def __init__(self):
537        """
538        Base class for scipy and park fit engine
539        """
540        #List of parameter names to fit
541        self.param_list = []
542        #Dictionnary of fitArrange element (fit problems)
543        self.fit_arrange_dict = {}
544 
545    def set_model(self, model,  id,  pars=[], constraints=[], data=None):
546        """
547        set a model on a given  in the fit engine.
548       
549        :param model: sans.models type
550        :param : is the key of the fitArrange dictionary where model is
551                saved as a value
552        :param pars: the list of parameters to fit
553        :param constraints: list of
554            tuple (name of parameter, value of parameters)
555            the value of parameter must be a string to constraint 2 different
556            parameters.
557            Example: 
558            we want to fit 2 model M1 and M2 both have parameters A and B.
559            constraints can be:
560             constraints = [(M1.A, M2.B+2), (M1.B= M2.A *5),...,]
561           
562             
563        :note: pars must contains only name of existing model's parameters
564       
565        """
566        if model == None:
567            raise ValueError, "AbstractFitEngine: Need to set model to fit"
568       
569        new_model = model
570        if not issubclass(model.__class__, Model):
571            new_model = Model(model, data)
572       
573        if len(constraints) > 0:
574            for constraint in constraints:
575                name, value = constraint
576                try:
577                    new_model.parameterset[str(name)].set(str(value))
578                except:
579                    msg = "Fit Engine: Error occurs when setting the constraint"
580                    msg += " %s for parameter %s " % (value, name)
581                    raise ValueError, msg
582               
583        if len(pars) > 0:
584            temp = []
585            for item in pars:
586                if item in new_model.model.getParamList():
587                    temp.append(item)
588                    self.param_list.append(item)
589                else:
590                   
591                    msg = "wrong parameter %s used" % str(item)
592                    msg += "to set model %s. Choose" % str(new_model.model.name)
593                    msg += "parameter name within %s" % \
594                                str(new_model.model.getParamList())
595                    raise ValueError, msg
596             
597            #A fitArrange is already created but contains data_list only at id
598            if self.fit_arrange_dict.has_key(id):
599                self.fit_arrange_dict[id].set_model(new_model)
600                self.fit_arrange_dict[id].pars = pars
601            else:
602            #no fitArrange object has been create with this id
603                fitproblem = FitArrange()
604                fitproblem.set_model(new_model)
605                fitproblem.pars = pars
606                self.fit_arrange_dict[id] = fitproblem
607               
608        else:
609            raise ValueError, "park_integration:missing parameters"
610   
611    def set_data(self, data, id, smearer=None, qmin=None, qmax=None):
612        """
613        Receives plottable, creates a list of data to fit,set data
614        in a FitArrange object and adds that object in a dictionary
615        with key id.
616       
617        :param data: data added
618        :param id: unique key corresponding to a fitArrange object with data
619       
620        """
621        if data.__class__.__name__ == 'Data2D':
622            fitdata = FitData2D(sans_data2d=data, data=data.data,
623                                 err_data=data.err_data)
624        else:
625            fitdata = FitData1D(x=data.x, y=data.y ,
626                                 dx=data.dx, dy=data.dy, smearer=smearer)
627       
628        fitdata.set_fit_range(qmin=qmin, qmax=qmax)
629        #A fitArrange is already created but contains model only at id
630        if self.fit_arrange_dict.has_key(id):
631            self.fit_arrange_dict[id].add_data(fitdata)
632        else:
633        #no fitArrange object has been create with this id
634            fitproblem = FitArrange()
635            fitproblem.add_data(fitdata)
636            self.fit_arrange_dict[id] = fitproblem   
637   
638    def get_model(self, id):
639        """
640       
641        :param id: id is key in the dictionary containing the model to return
642       
643        :return:  a model at this id or None if no FitArrange element was
644            created with this id
645           
646        """
647        if self.fit_arrange_dict.has_key(id):
648            return self.fit_arrange_dict[id].get_model()
649        else:
650            return None
651   
652    def remove_fit_problem(self, id):
653        """remove   fitarrange in id"""
654        if self.fit_arrange_dict.has_key(id):
655            del self.fit_arrange_dict[id]
656           
657    def select_problem_for_fit(self, id, value):
658        """
659        select a couple of model and data at the id position in dictionary
660        and set in self.selected value to value
661       
662        :param value: the value to allow fitting.
663                can only have the value one or zero
664               
665        """
666        if self.fit_arrange_dict.has_key(id):
667            self.fit_arrange_dict[id].set_to_fit(value)
668             
669    def get_problem_to_fit(self, id):
670        """
671        return the self.selected value of the fit problem of id
672       
673        :param id: the id of the problem
674       
675        """
676        if self.fit_arrange_dict.has_key(id):
677            self.fit_arrange_dict[id].get_to_fit()
678   
679class FitArrange:
680    def __init__(self):
681        """
682        Class FitArrange contains a set of data for a given model
683        to perform the Fit.FitArrange must contain exactly one model
684        and at least one data for the fit to be performed.
685       
686        model: the model selected by the user
687        Ldata: a list of data what the user wants to fit
688           
689        """
690        self.model = None
691        self.data_list = []
692        self.pars = []
693        #self.selected  is zero when this fit problem is not schedule to fit
694        #self.selected is 1 when schedule to fit
695        self.selected = 0
696       
697    def set_model(self, model):
698        """
699        set_model save a copy of the model
700       
701        :param model: the model being set
702       
703        """
704        self.model = model
705       
706    def add_data(self, data):
707        """
708        add_data fill a self.data_list with data to fit
709       
710        :param data: Data to add in the list 
711       
712        """
713        if not data in self.data_list:
714            self.data_list.append(data)
715           
716    def get_model(self):
717        """
718       
719        :return: saved model
720       
721        """
722        return self.model   
723     
724    def get_data(self):
725        """
726       
727        :return: list of data data_list
728       
729        """
730        #return self.data_list
731        return self.data_list[0] 
732     
733    def remove_data(self, data):
734        """
735        Remove one element from the list
736       
737        :param data: Data to remove from data_list
738       
739        """
740        if data in self.data_list:
741            self.data_list.remove(data)
742           
743    def set_to_fit (self, value=0):
744        """
745        set self.selected to 0 or 1  for other values raise an exception
746       
747        :param value: integer between 0 or 1
748       
749        """
750        self.selected = value
751       
752    def get_to_fit(self):
753        """
754        return self.selected value
755        """
756        return self.selected
Note: See TracBrowser for help on using the repository browser.