source: sasview/park_integration/src/sans/fit/AbstractFitEngine.py @ 1c86a37

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 1c86a37 was 3dcbe99, checked in by Jae Cho <jhjcho@…>, 13 years ago

2d fit fix

  • Property mode set to 100644
File size: 25.9 KB
Line 
1
2import  copy
3#import logging
4#import sys
5import numpy
6import math
7import park
8from sans.dataloader.data_info import Data1D
9from sans.dataloader.data_info import Data2D
10import time
11_SMALLVALUE = 1.0e-10   
12   
13class SansParameter(park.Parameter):
14    """
15    SANS model parameters for use in the PARK fitting service.
16    The parameter attribute value is redirected to the underlying
17    parameter value in the SANS model.
18    """
19    def __init__(self, name, model, data):
20        """
21        :param name: the name of the model parameter
22        :param model: the sans model to wrap as a park model
23       
24        """
25        park.Parameter.__init__(self, name)
26        self._model, self._name = model, name
27        self.data = data
28        self.model = model
29        #set the value for the parameter of the given name
30        self.set(model.getParam(name))
31         
32    def _getvalue(self):
33        """
34        override the _getvalue of park parameter
35       
36        :return value the parameter associates with self.name
37       
38        """
39        return self._model.getParam(self.name)
40   
41    def _setvalue(self, value):
42        """
43        override the _setvalue pf park parameter
44       
45        :param value: the value to set on a given parameter
46       
47        """
48        self._model.setParam(self.name, value)
49       
50    value = property(_getvalue, _setvalue)
51   
52    def _getrange(self):
53        """
54        Override _getrange of park parameter
55        return the range of parameter
56        """
57        #if not  self.name in self._model.getDispParamList():
58        lo, hi = self._model.details[self.name][1:3]
59        if lo is None: lo = -numpy.inf
60        if hi is None: hi = numpy.inf
61        #else:
62            #lo,hi = self._model.details[self.name][1:]
63            #if lo is None: lo = -numpy.inf
64            #if hi is None: hi = numpy.inf
65        if lo > hi:
66            raise ValueError,"wrong fit range for parameters"
67       
68        return lo, hi
69   
70    def get_name(self):
71        """
72        """
73        return self._getname()
74   
75    def _setrange(self, r):
76        """
77        override _setrange of park parameter
78       
79        :param r: the value of the range to set
80       
81        """
82        self._model.details[self.name][1:3] = r
83    range = property(_getrange, _setrange)
84   
85class Model(park.Model):
86    """
87    PARK wrapper for SANS models.
88    """
89    def __init__(self, sans_model, sans_data=None, **kw):
90        """
91        :param sans_model: the sans model to wrap using park interface
92       
93        """
94        park.Model.__init__(self, **kw)
95        self.model = sans_model
96        self.name = sans_model.name
97        self.data = sans_data
98        #list of parameters names
99        self.sansp = sans_model.getParamList()
100        #list of park parameter
101        self.parkp = [SansParameter(p, sans_model, sans_data) for p in self.sansp]
102        #list of parameterset
103        self.parameterset = park.ParameterSet(sans_model.name, pars=self.parkp)
104        self.pars = []
105 
106    def get_params(self, fitparams):
107        """
108        return a list of value of paramter to fit
109       
110        :param fitparams: list of paramaters name to fit
111       
112        """
113        list_params = []
114        self.pars = []
115        self.pars = fitparams
116        for item in fitparams:
117            for element in self.parkp:
118                if element.name == str(item):
119                    list_params.append(element.value)
120        return list_params
121   
122    def set_params(self, paramlist, params):
123        """
124        Set value for parameters to fit
125       
126        :param params: list of value for parameters to fit
127       
128        """
129        try:
130            for i in range(len(self.parkp)):
131                for j in range(len(paramlist)):
132                    if self.parkp[i].name == paramlist[j]:
133                        self.parkp[i].value = params[j]
134                        self.model.setParam(self.parkp[i].name, params[j])
135        except:
136            raise
137 
138    def eval(self, x):
139        """
140        override eval method of park model.
141       
142        :param x: the x value used to compute a function
143       
144        """
145        try:
146            return self.model.evalDistribution(x)
147        except:
148            raise
149       
150    def eval_derivs(self, x, pars=[]):
151        """
152        Evaluate the model and derivatives wrt pars at x.
153
154        pars is a list of the names of the parameters for which derivatives
155        are desired.
156
157        This method needs to be specialized in the model to evaluate the
158        model function.  Alternatively, the model can implement is own
159        version of residuals which calculates the residuals directly
160        instead of calling eval.
161        """
162        return []
163
164
165   
166class FitData1D(Data1D):
167    """
168    Wrapper class  for SANS data
169    FitData1D inherits from DataLoader.data_info.Data1D. Implements
170    a way to get residuals from data.
171    """
172    def __init__(self, x, y, dx=None, dy=None, smearer=None, data=None):
173        """
174        :param smearer: is an object of class QSmearer or SlitSmearer
175           that will smear the theory data (slit smearing or resolution
176           smearing) when set.
177       
178        The proper way to set the smearing object would be to
179        do the following: ::
180       
181            from DataLoader.qsmearing import smear_selection
182            smearer = smear_selection(some_data)
183            fitdata1d = FitData1D( x= [1,3,..,],
184                                    y= [3,4,..,8],
185                                    dx=None,
186                                    dy=[1,2...], smearer= smearer)
187       
188        :Note: that some_data _HAS_ to be of class DataLoader.data_info.Data1D
189            Setting it back to None will turn smearing off.
190           
191        """
192        Data1D.__init__(self, x=x, y=y, dx=dx, dy=dy)
193        self.sans_data = data
194        self.smearer = smearer
195        self._first_unsmeared_bin = None
196        self._last_unsmeared_bin = None
197        # Check error bar; if no error bar found, set it constant(=1)
198        # TODO: Should provide an option for users to set it like percent,
199        # constant, or dy data
200        if dy == None or dy == [] or dy.all() == 0:
201            self.dy = numpy.ones(len(y)) 
202        else:
203            self.dy = numpy.asarray(dy).copy()
204
205        ## Min Q-value
206        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
207        if min (self.x) == 0.0 and self.x[0] == 0 and\
208                     not numpy.isfinite(self.y[0]):
209            self.qmin = min(self.x[self.x!=0])
210        else:                             
211            self.qmin = min(self.x)
212        ## Max Q-value
213        self.qmax = max(self.x)
214       
215        # Range used for input to smearing
216        self._qmin_unsmeared = self.qmin
217        self._qmax_unsmeared = self.qmax
218        # Identify the bin range for the unsmeared and smeared spaces
219        self.idx = (self.x >= self.qmin) & (self.x <= self.qmax)
220        self.idx_unsmeared = (self.x >= self._qmin_unsmeared) \
221                            & (self.x <= self._qmax_unsmeared)
222 
223    def set_fit_range(self, qmin=None, qmax=None):
224        """ to set the fit range"""
225        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
226        # ToDo: Find better way to do it.
227        if qmin == 0.0 and not numpy.isfinite(self.y[qmin]):
228            self.qmin = min(self.x[self.x != 0])
229        elif qmin != None:                       
230            self.qmin = qmin           
231        if qmax != None:
232            self.qmax = qmax
233        # Determine the range needed in unsmeared-Q to cover
234        # the smeared Q range
235        self._qmin_unsmeared = self.qmin
236        self._qmax_unsmeared = self.qmax   
237       
238        self._first_unsmeared_bin = 0
239        self._last_unsmeared_bin  = len(self.x) - 1
240       
241        if self.smearer != None:
242            self._first_unsmeared_bin, self._last_unsmeared_bin = \
243                    self.smearer.get_bin_range(self.qmin, self.qmax)
244            self._qmin_unsmeared = self.x[self._first_unsmeared_bin]
245            self._qmax_unsmeared = self.x[self._last_unsmeared_bin]
246           
247        # Identify the bin range for the unsmeared and smeared spaces
248        self.idx = (self.x >= self.qmin) & (self.x <= self.qmax)
249        ## zero error can not participate for fitting
250        self.idx = self.idx & (self.dy != 0) 
251        self.idx_unsmeared = (self.x >= self._qmin_unsmeared) \
252                            & (self.x <= self._qmax_unsmeared)
253 
254    def get_fit_range(self):
255        """
256        return the range of data.x to fit
257        """
258        return self.qmin, self.qmax
259       
260    def residuals(self, fn):
261        """
262        Compute residuals.
263       
264        If self.smearer has been set, use if to smear
265        the data before computing chi squared.
266       
267        :param fn: function that return model value
268       
269        :return: residuals
270       
271        """
272        # Compute theory data f(x)
273        fx = numpy.zeros(len(self.x))
274        fx[self.idx_unsmeared] = fn(self.x[self.idx_unsmeared])
275       
276        ## Smear theory data
277        if self.smearer is not None:
278            fx = self.smearer(fx, self._first_unsmeared_bin, 
279                              self._last_unsmeared_bin)
280        ## Sanity check
281        if numpy.size(self.dy) != numpy.size(fx):
282            msg = "FitData1D: invalid error array "
283            msg += "%d <> %d" % (numpy.shape(self.dy), numpy.size(fx))                                                     
284            raise RuntimeError, msg 
285        return (self.y[self.idx] - fx[self.idx]) / self.dy[self.idx], fx[self.idx]
286     
287    def residuals_deriv(self, model, pars=[]):
288        """
289        :return: residuals derivatives .
290       
291        :note: in this case just return empty array
292       
293        """
294        return []
295   
296class FitData2D(Data2D):
297    """ Wrapper class  for SANS data """
298    def __init__(self, sans_data2d, data=None, err_data=None):
299        Data2D.__init__(self, data=data, err_data=err_data)
300        """
301        Data can be initital with a data (sans plottable)
302        or with vectors.
303        """
304        self.res_err_image = []
305        self.index_model = []
306        self.qmin = None
307        self.qmax = None
308        self.smearer = None
309        self.radius = 0
310        self.res_err_data = []
311        self.sans_data = sans_data2d
312        self.set_data(sans_data2d)
313
314    def set_data(self, sans_data2d, qmin=None, qmax=None):
315        """
316        Determine the correct qx_data and qy_data within range to fit
317        """
318        self.data = sans_data2d.data
319        self.err_data = sans_data2d.err_data
320        self.qx_data = sans_data2d.qx_data
321        self.qy_data = sans_data2d.qy_data
322        self.mask = sans_data2d.mask
323
324        x_max = max(math.fabs(sans_data2d.xmin), math.fabs(sans_data2d.xmax))
325        y_max = max(math.fabs(sans_data2d.ymin), math.fabs(sans_data2d.ymax))
326       
327        ## fitting range
328        if qmin == None:
329            self.qmin = 1e-16
330        if qmax == None:
331            self.qmax = math.sqrt(x_max * x_max + y_max * y_max)
332        ## new error image for fitting purpose
333        if self.err_data == None or self.err_data == []:
334            self.res_err_data = numpy.ones(len(self.data))
335        else:
336            self.res_err_data = copy.deepcopy(self.err_data)
337        #self.res_err_data[self.res_err_data==0]=1
338       
339        self.radius = numpy.sqrt(self.qx_data**2 + self.qy_data**2)
340       
341        # Note: mask = True: for MASK while mask = False for NOT to mask
342        self.index_model = ((self.qmin <= self.radius)&\
343                            (self.radius <= self.qmax))
344        self.index_model = (self.index_model) & (self.mask)
345        self.index_model = (self.index_model) & (numpy.isfinite(self.data))
346       
347    def set_smearer(self, smearer): 
348        """
349        Set smearer
350        """
351        if smearer == None:
352            return
353        self.smearer = smearer
354        self.smearer.set_index(self.index_model)
355        self.smearer.get_data()
356
357    def set_fit_range(self, qmin=None, qmax=None):
358        """ to set the fit range"""
359        if qmin == 0.0:
360            self.qmin = 1e-16
361        elif qmin != None:                       
362            self.qmin = qmin           
363        if qmax != None:
364            self.qmax = qmax       
365        self.radius = numpy.sqrt(self.qx_data**2 + self.qy_data**2)
366        self.index_model = ((self.qmin <= self.radius)&\
367                            (self.radius <= self.qmax))
368        self.index_model = (self.index_model) &(self.mask)
369        self.index_model = (self.index_model) & (numpy.isfinite(self.data))
370        self.index_model = (self.index_model) & (self.res_err_data != 0)
371       
372    def get_fit_range(self):
373        """
374        return the range of data.x to fit
375        """
376        return self.qmin, self.qmax
377     
378    def residuals(self, fn): 
379        """
380        return the residuals
381        """ 
382        if self.smearer != None:
383            fn.set_index(self.index_model)
384            # Get necessary data from self.data and set the data for smearing
385            fn.get_data()
386
387            gn = fn.get_value() 
388        else:
389            gn = fn([self.qx_data[self.index_model],
390                     self.qy_data[self.index_model]])
391        # use only the data point within ROI range
392        res = (self.data[self.index_model] - gn)/\
393                    self.res_err_data[self.index_model]
394        return res, gn
395       
396    def residuals_deriv(self, model, pars=[]):
397        """
398        :return: residuals derivatives .
399       
400        :note: in this case just return empty array
401       
402        """
403        return []
404   
405class FitAbort(Exception):
406    """
407    Exception raise to stop the fit
408    """
409    #pass
410    #print"Creating fit abort Exception"
411
412
413class SansAssembly:
414    """
415    Sans Assembly class a class wrapper to be call in optimizer.leastsq method
416    """
417    def __init__(self, paramlist, model=None , data=None, fitresult=None,
418                 handler=None, curr_thread=None):
419        """
420        :param Model: the model wrapper fro sans -model
421        :param Data: the data wrapper for sans data
422       
423        """
424        self.model = model
425        self.data  = data
426        self.paramlist = paramlist
427        self.curr_thread = curr_thread
428        self.handler = handler
429        self.fitresult = fitresult
430        self.res = []
431        self.true_res = []
432        self.func_name = "Functor"
433        self.theory = None
434       
435    #def chisq(self, params):
436    def chisq(self):
437        """
438        Calculates chi^2
439       
440        :param params: list of parameter values
441       
442        :return: chi^2
443       
444        """
445        sum = 0
446        for item in self.true_res:
447            sum += item * item
448        if len(self.true_res) == 0:
449            return None
450        return sum / len(self.true_res)
451   
452    def __call__(self, params):
453        """
454        Compute residuals
455       
456        :param params: value of parameters to fit
457       
458        """ 
459        #import thread
460        self.model.set_params(self.paramlist, params)
461       
462        self.true_res, self.theory = self.data.residuals(self.model.eval)
463        # check parameters range
464        if self.check_param_range():
465            # if the param value is outside of the bound
466            # just silent return res = inf
467            return self.res
468        self.res = self.true_res       
469        if self.fitresult is not None and  self.handler is not None:
470            self.fitresult.set_model(model=self.model)
471            #fitness = self.chisq(params=params)
472            fitness = self.chisq()
473            self.fitresult.pvec = params
474            self.fitresult.set_fitness(fitness=fitness)
475            self.handler.set_result(result=self.fitresult)
476            self.handler.update_fit()
477
478            if self.curr_thread != None :
479                try:
480                    self.curr_thread.isquit()
481                except:
482                    msg = "Fitting: Terminated...       Note: Forcing to stop " 
483                    msg += "fitting may cause a 'Functor error message' "
484                    msg += "being recorded in the log file....."
485                    self.handler.error(msg)
486                    raise
487                    #return
488         
489        return self.res
490   
491    def check_param_range(self):
492        """
493        Check the lower and upper bound of the parameter value
494        and set res to the inf if the value is outside of the
495        range
496        :limitation: the initial values must be within range.
497        """
498
499        #time.sleep(0.01)
500        is_outofbound = False
501        # loop through the fit parameters
502        for p in self.model.parameterset:
503            param_name = p.get_name()
504            if param_name in self.paramlist:
505               
506                # if the range was defined, check the range
507                if numpy.isfinite(p.range[0]):
508                    if p.value == 0:
509                        # This value works on Scipy
510                        # Do not change numbers below
511                        value = _SMALLVALUE
512                    else:
513                        value = p.value
514                    # For leastsq, it needs a bit step back from the boundary
515                    val = p.range[0] - value * _SMALLVALUE
516                    if p.value < val: 
517                        self.res *= 1e+6
518                       
519                        is_outofbound = True
520                        break
521                if numpy.isfinite(p.range[1]):
522                    # This value works on Scipy
523                    # Do not change numbers below
524                    if p.value == 0:
525                        value = _SMALLVALUE
526                    else:
527                        value = p.value
528                    # For leastsq, it needs a bit step back from the boundary
529                    val = p.range[1] + value * _SMALLVALUE
530                    if p.value > val:
531                        self.res *= 1e+6
532                        is_outofbound = True
533                        break
534
535        return is_outofbound
536   
537   
538class FitEngine:
539    def __init__(self):
540        """
541        Base class for scipy and park fit engine
542        """
543        #List of parameter names to fit
544        self.param_list = []
545        #Dictionnary of fitArrange element (fit problems)
546        self.fit_arrange_dict = {}
547 
548    def set_model(self, model,  id,  pars=[], constraints=[], data=None):
549        """
550        set a model on a given  in the fit engine.
551       
552        :param model: sans.models type
553        :param : is the key of the fitArrange dictionary where model is
554                saved as a value
555        :param pars: the list of parameters to fit
556        :param constraints: list of
557            tuple (name of parameter, value of parameters)
558            the value of parameter must be a string to constraint 2 different
559            parameters.
560            Example: 
561            we want to fit 2 model M1 and M2 both have parameters A and B.
562            constraints can be:
563             constraints = [(M1.A, M2.B+2), (M1.B= M2.A *5),...,]
564           
565             
566        :note: pars must contains only name of existing model's parameters
567       
568        """
569        if model == None:
570            raise ValueError, "AbstractFitEngine: Need to set model to fit"
571       
572        new_model = model
573        if not issubclass(model.__class__, Model):
574            new_model = Model(model, data)
575       
576        if len(constraints) > 0:
577            for constraint in constraints:
578                name, value = constraint
579                try:
580                    new_model.parameterset[str(name)].set(str(value))
581                except:
582                    msg = "Fit Engine: Error occurs when setting the constraint"
583                    msg += " %s for parameter %s " % (value, name)
584                    raise ValueError, msg
585               
586        if len(pars) > 0:
587            temp = []
588            for item in pars:
589                if item in new_model.model.getParamList():
590                    temp.append(item)
591                    self.param_list.append(item)
592                else:
593                   
594                    msg = "wrong parameter %s used" % str(item)
595                    msg += "to set model %s. Choose" % str(new_model.model.name)
596                    msg += "parameter name within %s" % \
597                                str(new_model.model.getParamList())
598                    raise ValueError, msg
599             
600            #A fitArrange is already created but contains data_list only at id
601            if self.fit_arrange_dict.has_key(id):
602                self.fit_arrange_dict[id].set_model(new_model)
603                self.fit_arrange_dict[id].pars = pars
604            else:
605            #no fitArrange object has been create with this id
606                fitproblem = FitArrange()
607                fitproblem.set_model(new_model)
608                fitproblem.pars = pars
609                self.fit_arrange_dict[id] = fitproblem
610               
611        else:
612            raise ValueError, "park_integration:missing parameters"
613   
614    def set_data(self, data, id, smearer=None, qmin=None, qmax=None):
615        """
616        Receives plottable, creates a list of data to fit,set data
617        in a FitArrange object and adds that object in a dictionary
618        with key id.
619       
620        :param data: data added
621        :param id: unique key corresponding to a fitArrange object with data
622       
623        """
624        if data.__class__.__name__ == 'Data2D':
625            fitdata = FitData2D(sans_data2d=data, data=data.data,
626                                 err_data=data.err_data)
627        else:
628            fitdata = FitData1D(x=data.x, y=data.y ,
629                                 dx=data.dx, dy=data.dy, smearer=smearer)
630        fitdata.sans_data = data
631       
632        fitdata.set_fit_range(qmin=qmin, qmax=qmax)
633        #A fitArrange is already created but contains model only at id
634        if self.fit_arrange_dict.has_key(id):
635            self.fit_arrange_dict[id].add_data(fitdata)
636        else:
637        #no fitArrange object has been create with this id
638            fitproblem = FitArrange()
639            fitproblem.add_data(fitdata)
640            self.fit_arrange_dict[id] = fitproblem   
641   
642    def get_model(self, id):
643        """
644       
645        :param id: id is key in the dictionary containing the model to return
646       
647        :return:  a model at this id or None if no FitArrange element was
648            created with this id
649           
650        """
651        if self.fit_arrange_dict.has_key(id):
652            return self.fit_arrange_dict[id].get_model()
653        else:
654            return None
655   
656    def remove_fit_problem(self, id):
657        """remove   fitarrange in id"""
658        if self.fit_arrange_dict.has_key(id):
659            del self.fit_arrange_dict[id]
660           
661    def select_problem_for_fit(self, id, value):
662        """
663        select a couple of model and data at the id position in dictionary
664        and set in self.selected value to value
665       
666        :param value: the value to allow fitting.
667                can only have the value one or zero
668               
669        """
670        if self.fit_arrange_dict.has_key(id):
671            self.fit_arrange_dict[id].set_to_fit(value)
672             
673    def get_problem_to_fit(self, id):
674        """
675        return the self.selected value of the fit problem of id
676       
677        :param id: the id of the problem
678       
679        """
680        if self.fit_arrange_dict.has_key(id):
681            self.fit_arrange_dict[id].get_to_fit()
682   
683class FitArrange:
684    def __init__(self):
685        """
686        Class FitArrange contains a set of data for a given model
687        to perform the Fit.FitArrange must contain exactly one model
688        and at least one data for the fit to be performed.
689       
690        model: the model selected by the user
691        Ldata: a list of data what the user wants to fit
692           
693        """
694        self.model = None
695        self.data_list = []
696        self.pars = []
697        #self.selected  is zero when this fit problem is not schedule to fit
698        #self.selected is 1 when schedule to fit
699        self.selected = 0
700       
701    def set_model(self, model):
702        """
703        set_model save a copy of the model
704       
705        :param model: the model being set
706       
707        """
708        self.model = model
709       
710    def add_data(self, data):
711        """
712        add_data fill a self.data_list with data to fit
713       
714        :param data: Data to add in the list 
715       
716        """
717        if not data in self.data_list:
718            self.data_list.append(data)
719           
720    def get_model(self):
721        """
722       
723        :return: saved model
724       
725        """
726        return self.model   
727     
728    def get_data(self):
729        """
730       
731        :return: list of data data_list
732       
733        """
734        #return self.data_list
735        return self.data_list[0] 
736     
737    def remove_data(self, data):
738        """
739        Remove one element from the list
740       
741        :param data: Data to remove from data_list
742       
743        """
744        if data in self.data_list:
745            self.data_list.remove(data)
746           
747    def set_to_fit (self, value=0):
748        """
749        set self.selected to 0 or 1  for other values raise an exception
750       
751        :param value: integer between 0 or 1
752       
753        """
754        self.selected = value
755       
756    def get_to_fit(self):
757        """
758        return self.selected value
759        """
760        return self.selected
Note: See TracBrowser for help on using the repository browser.