source: sasview/park_integration/AbstractFitEngine.py @ 40953a9

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 40953a9 was 92697e3, checked in by Jae Cho <jhjcho@…>, 13 years ago

minor msg change

  • Property mode set to 100644
File size: 25.3 KB
Line 
1
2import  copy
3#import logging
4#import sys
5import numpy
6import math
7import park
8from DataLoader.data_info import Data1D
9from DataLoader.data_info import Data2D
10import time
11_SMALLVALUE = 1.0e-10   
12   
13class SansParameter(park.Parameter):
14    """
15    SANS model parameters for use in the PARK fitting service.
16    The parameter attribute value is redirected to the underlying
17    parameter value in the SANS model.
18    """
19    def __init__(self, name, model):
20        """
21        :param name: the name of the model parameter
22        :param model: the sans model to wrap as a park model
23       
24        """
25        park.Parameter.__init__(self, name)
26        self._model, self._name = model, name
27        #set the value for the parameter of the given name
28        self.set(model.getParam(name))
29         
30    def _getvalue(self):
31        """
32        override the _getvalue of park parameter
33       
34        :return value the parameter associates with self.name
35       
36        """
37        return self._model.getParam(self.name)
38   
39    def _setvalue(self, value):
40        """
41        override the _setvalue pf park parameter
42       
43        :param value: the value to set on a given parameter
44       
45        """
46        self._model.setParam(self.name, value)
47       
48    value = property(_getvalue, _setvalue)
49   
50    def _getrange(self):
51        """
52        Override _getrange of park parameter
53        return the range of parameter
54        """
55        #if not  self.name in self._model.getDispParamList():
56        lo, hi = self._model.details[self.name][1:3]
57        if lo is None: lo = -numpy.inf
58        if hi is None: hi = numpy.inf
59        #else:
60            #lo,hi = self._model.details[self.name][1:]
61            #if lo is None: lo = -numpy.inf
62            #if hi is None: hi = numpy.inf
63        if lo >= hi:
64            raise ValueError,"wrong fit range for parameters"
65       
66        return lo, hi
67   
68    def get_name(self):
69        """
70        """
71        return self._getname()
72   
73    def _setrange(self, r):
74        """
75        override _setrange of park parameter
76       
77        :param r: the value of the range to set
78       
79        """
80        self._model.details[self.name][1:3] = r
81    range = property(_getrange, _setrange)
82   
83class Model(park.Model):
84    """
85    PARK wrapper for SANS models.
86    """
87    def __init__(self, sans_model, **kw):
88        """
89        :param sans_model: the sans model to wrap using park interface
90       
91        """
92        park.Model.__init__(self, **kw)
93        self.model = sans_model
94        self.name = sans_model.name
95        #list of parameters names
96        self.sansp = sans_model.getParamList()
97        #list of park parameter
98        self.parkp = [SansParameter(p, sans_model) for p in self.sansp]
99        #list of parameterset
100        self.parameterset = park.ParameterSet(sans_model.name, pars=self.parkp)
101        self.pars = []
102 
103    def get_params(self, fitparams):
104        """
105        return a list of value of paramter to fit
106       
107        :param fitparams: list of paramaters name to fit
108       
109        """
110        list_params = []
111        self.pars = []
112        self.pars = fitparams
113        for item in fitparams:
114            for element in self.parkp:
115                if element.name == str(item):
116                    list_params.append(element.value)
117        return list_params
118   
119    def set_params(self, paramlist, params):
120        """
121        Set value for parameters to fit
122       
123        :param params: list of value for parameters to fit
124       
125        """
126        try:
127            for i in range(len(self.parkp)):
128                for j in range(len(paramlist)):
129                    if self.parkp[i].name == paramlist[j]:
130                        self.parkp[i].value = params[j]
131                        self.model.setParam(self.parkp[i].name, params[j])
132        except:
133            raise
134 
135    def eval(self, x):
136        """
137        override eval method of park model.
138       
139        :param x: the x value used to compute a function
140       
141        """
142        try:
143            return self.model.evalDistribution(x)
144        except:
145            raise
146       
147    def eval_derivs(self, x, pars=[]):
148        """
149        Evaluate the model and derivatives wrt pars at x.
150
151        pars is a list of the names of the parameters for which derivatives
152        are desired.
153
154        This method needs to be specialized in the model to evaluate the
155        model function.  Alternatively, the model can implement is own
156        version of residuals which calculates the residuals directly
157        instead of calling eval.
158        """
159        return []
160
161
162   
163class FitData1D(Data1D):
164    """
165    Wrapper class  for SANS data
166    FitData1D inherits from DataLoader.data_info.Data1D. Implements
167    a way to get residuals from data.
168    """
169    def __init__(self, x, y, dx=None, dy=None, smearer=None):
170        """
171        :param smearer: is an object of class QSmearer or SlitSmearer
172           that will smear the theory data (slit smearing or resolution
173           smearing) when set.
174       
175        The proper way to set the smearing object would be to
176        do the following: ::
177       
178            from DataLoader.qsmearing import smear_selection
179            smearer = smear_selection(some_data)
180            fitdata1d = FitData1D( x= [1,3,..,],
181                                    y= [3,4,..,8],
182                                    dx=None,
183                                    dy=[1,2...], smearer= smearer)
184       
185        :Note: that some_data _HAS_ to be of class DataLoader.data_info.Data1D
186            Setting it back to None will turn smearing off.
187           
188        """
189        Data1D.__init__(self, x=x, y=y, dx=dx, dy=dy)
190       
191        self.smearer = smearer
192        self._first_unsmeared_bin = None
193        self._last_unsmeared_bin = None
194        # Check error bar; if no error bar found, set it constant(=1)
195        # TODO: Should provide an option for users to set it like percent,
196        # constant, or dy data
197        if dy == None or dy == [] or dy.all() == 0:
198            self.dy = numpy.ones(len(y)) 
199        else:
200            self.dy = numpy.asarray(dy).copy()
201
202        ## Min Q-value
203        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
204        if min (self.x) == 0.0 and self.x[0] == 0 and\
205                     not numpy.isfinite(self.y[0]):
206            self.qmin = min(self.x[self.x!=0])
207        else:                             
208            self.qmin = min(self.x)
209        ## Max Q-value
210        self.qmax = max(self.x)
211       
212        # Range used for input to smearing
213        self._qmin_unsmeared = self.qmin
214        self._qmax_unsmeared = self.qmax
215        # Identify the bin range for the unsmeared and smeared spaces
216        self.idx = (self.x >= self.qmin) & (self.x <= self.qmax)
217        self.idx_unsmeared = (self.x >= self._qmin_unsmeared) \
218                            & (self.x <= self._qmax_unsmeared)
219 
220    def set_fit_range(self, qmin=None, qmax=None):
221        """ to set the fit range"""
222        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
223        # ToDo: Find better way to do it.
224        if qmin == 0.0 and not numpy.isfinite(self.y[qmin]):
225            self.qmin = min(self.x[self.x != 0])
226        elif qmin != None:                       
227            self.qmin = qmin           
228        if qmax != None:
229            self.qmax = qmax
230        # Determine the range needed in unsmeared-Q to cover
231        # the smeared Q range
232        self._qmin_unsmeared = self.qmin
233        self._qmax_unsmeared = self.qmax   
234       
235        self._first_unsmeared_bin = 0
236        self._last_unsmeared_bin  = len(self.x) - 1
237       
238        if self.smearer != None:
239            self._first_unsmeared_bin, self._last_unsmeared_bin = \
240                    self.smearer.get_bin_range(self.qmin, self.qmax)
241            self._qmin_unsmeared = self.x[self._first_unsmeared_bin]
242            self._qmax_unsmeared = self.x[self._last_unsmeared_bin]
243           
244        # Identify the bin range for the unsmeared and smeared spaces
245        self.idx = (self.x >= self.qmin) & (self.x <= self.qmax)
246        ## zero error can not participate for fitting
247        self.idx = self.idx & (self.dy != 0) 
248        self.idx_unsmeared = (self.x >= self._qmin_unsmeared) \
249                            & (self.x <= self._qmax_unsmeared)
250 
251    def get_fit_range(self):
252        """
253        return the range of data.x to fit
254        """
255        return self.qmin, self.qmax
256       
257    def residuals(self, fn):
258        """
259        Compute residuals.
260       
261        If self.smearer has been set, use if to smear
262        the data before computing chi squared.
263       
264        :param fn: function that return model value
265       
266        :return: residuals
267       
268        """
269        # Compute theory data f(x)
270        fx = numpy.zeros(len(self.x))
271        fx[self.idx_unsmeared] = fn(self.x[self.idx_unsmeared])
272       
273        ## Smear theory data
274        if self.smearer is not None:
275            fx = self.smearer(fx, self._first_unsmeared_bin, 
276                              self._last_unsmeared_bin)
277        ## Sanity check
278        if numpy.size(self.dy) != numpy.size(fx):
279            msg = "FitData1D: invalid error array "
280            msg += "%d <> %d" % (numpy.shape(self.dy), numpy.size(fx))                                                     
281            raise RuntimeError, msg 
282        return (self.y[self.idx] - fx[self.idx]) / self.dy[self.idx]
283     
284    def residuals_deriv(self, model, pars=[]):
285        """
286        :return: residuals derivatives .
287       
288        :note: in this case just return empty array
289       
290        """
291        return []
292   
293class FitData2D(Data2D):
294    """ Wrapper class  for SANS data """
295    def __init__(self, sans_data2d, data=None, err_data=None):
296        Data2D.__init__(self, data=data, err_data=err_data)
297        """
298        Data can be initital with a data (sans plottable)
299        or with vectors.
300        """
301        self.res_err_image = []
302        self.index_model = []
303        self.qmin = None
304        self.qmax = None
305        self.smearer = None
306        self.radius = 0
307        self.res_err_data = []
308        self.set_data(sans_data2d)
309
310    def set_data(self, sans_data2d, qmin=None, qmax=None):
311        """
312        Determine the correct qx_data and qy_data within range to fit
313        """
314        self.data = sans_data2d.data
315        self.err_data = sans_data2d.err_data
316        self.qx_data = sans_data2d.qx_data
317        self.qy_data = sans_data2d.qy_data
318        self.mask = sans_data2d.mask
319
320        x_max = max(math.fabs(sans_data2d.xmin), math.fabs(sans_data2d.xmax))
321        y_max = max(math.fabs(sans_data2d.ymin), math.fabs(sans_data2d.ymax))
322       
323        ## fitting range
324        if qmin == None:
325            self.qmin = 1e-16
326        if qmax == None:
327            self.qmax = math.sqrt(x_max * x_max + y_max * y_max)
328        ## new error image for fitting purpose
329        if self.err_data == None or self.err_data == []:
330            self.res_err_data = numpy.ones(len(self.data))
331        else:
332            self.res_err_data = copy.deepcopy(self.err_data)
333        #self.res_err_data[self.res_err_data==0]=1
334       
335        self.radius = numpy.sqrt(self.qx_data**2 + self.qy_data**2)
336       
337        # Note: mask = True: for MASK while mask = False for NOT to mask
338        self.index_model = ((self.qmin <= self.radius)&\
339                            (self.radius <= self.qmax))
340        self.index_model = (self.index_model) & (self.mask)
341        self.index_model = (self.index_model) & (numpy.isfinite(self.data))
342       
343    def set_smearer(self, smearer): 
344        """
345        Set smearer
346        """
347        if smearer == None:
348            return
349        self.smearer = smearer
350        self.smearer.set_index(self.index_model)
351        self.smearer.get_data()
352
353    def set_fit_range(self, qmin=None, qmax=None):
354        """ to set the fit range"""
355        if qmin == 0.0:
356            self.qmin = 1e-16
357        elif qmin != None:                       
358            self.qmin = qmin           
359        if qmax != None:
360            self.qmax = qmax       
361        self.radius = numpy.sqrt(self.qx_data**2 + self.qy_data**2)
362        self.index_model = ((self.qmin <= self.radius)&\
363                            (self.radius <= self.qmax))
364        self.index_model = (self.index_model) &(self.mask)
365        self.index_model = (self.index_model) & (numpy.isfinite(self.data))
366        self.index_model = (self.index_model) & (self.res_err_data != 0)
367       
368    def get_fit_range(self):
369        """
370        return the range of data.x to fit
371        """
372        return self.qmin, self.qmax
373     
374    def residuals(self, fn): 
375        """
376        return the residuals
377        """ 
378        if self.smearer != None:
379            fn.set_index(self.index_model)
380            # Get necessary data from self.data and set the data for smearing
381            fn.get_data()
382
383            gn = fn.get_value() 
384        else:
385            gn = fn([self.qx_data[self.index_model],
386                     self.qy_data[self.index_model]])
387        # use only the data point within ROI range
388        res = (self.data[self.index_model] - gn)/\
389                    self.res_err_data[self.index_model]
390        return res
391       
392    def residuals_deriv(self, model, pars=[]):
393        """
394        :return: residuals derivatives .
395       
396        :note: in this case just return empty array
397       
398        """
399        return []
400   
401class FitAbort(Exception):
402    """
403    Exception raise to stop the fit
404    """
405
406
407class SansAssembly:
408    """
409    Sans Assembly class a class wrapper to be call in optimizer.leastsq method
410    """
411    def __init__(self, paramlist, model=None , data=None, fitresult=None,
412                 handler=None, curr_thread=None):
413        """
414        :param Model: the model wrapper fro sans -model
415        :param Data: the data wrapper for sans data
416       
417        """
418        self.model = model
419        self.data  = data
420        self.paramlist = paramlist
421        self.curr_thread = curr_thread
422        self.handler = handler
423        self.fitresult = fitresult
424        self.res = []
425        self.true_res = []
426        self.func_name = "Functor"
427       
428    #def chisq(self, params):
429    def chisq(self):
430        """
431        Calculates chi^2
432       
433        :param params: list of parameter values
434       
435        :return: chi^2
436       
437        """
438        sum = 0
439        for item in self.true_res:
440            sum += item * item
441        if len(self.true_res) == 0:
442            return None
443        return sum / len(self.true_res)
444   
445    def __call__(self, params):
446        """
447        Compute residuals
448       
449        :param params: value of parameters to fit
450       
451        """ 
452        #import thread
453        self.model.set_params(self.paramlist, params)
454        self.true_res = self.data.residuals(self.model.eval)
455        # check parameters range
456        if self.check_param_range():
457            # if the param value is outside of the bound
458            # just silent return res = inf
459            return self.res
460        self.res = self.true_res       
461        if self.fitresult is not None and  self.handler is not None:
462            self.fitresult.set_model(model=self.model)
463            #fitness = self.chisq(params=params)
464            fitness = self.chisq()
465            self.fitresult.pvec = params
466            self.fitresult.set_fitness(fitness=fitness)
467            self.handler.set_result(result=self.fitresult)
468            self.handler.update_fit()
469
470            if self.curr_thread != None :
471                try:
472                    self.curr_thread.isquit()
473                except:
474                   self.handler.error("Terminating Fitting;")
475                   raise FitAbort,"The LeastSqr Fit: terminated by the user."   
476         
477        return self.res
478   
479    def check_param_range(self):
480        """
481        Check the lower and upper bound of the parameter value
482        and set res to the inf if the value is outside of the
483        range
484        :limitation: the initial values must be within range.
485        """
486
487        time.sleep(0.01)
488        is_outofbound = False
489        # loop through the fit parameters
490        for p in self.model.parameterset:
491            param_name = p.get_name()
492            if param_name in self.paramlist:
493               
494                # if the range was defined, check the range
495                if numpy.isfinite(p.range[0]):
496                    if p.value == 0:
497                        # This value works on Scipy
498                        # Do not change numbers below
499                        value = _SMALLVALUE
500                    else:
501                        value = p.value
502                    # For leastsq, it needs a bit step back from the boundary
503                    val = p.range[0] - value * _SMALLVALUE
504                    if p.value < val: 
505                        self.res *= 1e+6
506                       
507                        is_outofbound = True
508                        break
509                if numpy.isfinite(p.range[1]):
510                    # This value works on Scipy
511                    # Do not change numbers below
512                    if p.value == 0:
513                        value = _SMALLVALUE
514                    else:
515                        value = p.value
516                    # For leastsq, it needs a bit step back from the boundary
517                    val = p.range[1] + value * _SMALLVALUE
518                    if p.value > val:
519                        self.res *= 1e+6
520                        is_outofbound = True
521                        break
522
523        return is_outofbound
524   
525   
526class FitEngine:
527    def __init__(self):
528        """
529        Base class for scipy and park fit engine
530        """
531        #List of parameter names to fit
532        self.param_list = []
533        #Dictionnary of fitArrange element (fit problems)
534        self.fit_arrange_dict = {}
535 
536    def set_model(self, model, id, pars=[], constraints=[]):
537        """
538        set a model on a given  in the fit engine.
539       
540        :param model: sans.models type
541        :param : is the key of the fitArrange dictionary where model is
542                saved as a value
543        :param pars: the list of parameters to fit
544        :param constraints: list of
545            tuple (name of parameter, value of parameters)
546            the value of parameter must be a string to constraint 2 different
547            parameters.
548            Example: 
549            we want to fit 2 model M1 and M2 both have parameters A and B.
550            constraints can be:
551             constraints = [(M1.A, M2.B+2), (M1.B= M2.A *5),...,]
552           
553             
554        :note: pars must contains only name of existing model's parameters
555       
556        """
557        if model == None:
558            raise ValueError, "AbstractFitEngine: Need to set model to fit"
559       
560        new_model = model
561        if not issubclass(model.__class__, Model):
562            new_model = Model(model)
563       
564        if len(constraints) > 0:
565            for constraint in constraints:
566                name, value = constraint
567                try:
568                    new_model.parameterset[str(name)].set(str(value))
569                except:
570                    msg = "Fit Engine: Error occurs when setting the constraint"
571                    msg += " %s for parameter %s " % (value, name)
572                    raise ValueError, msg
573               
574        if len(pars) > 0:
575            temp = []
576            for item in pars:
577                if item in new_model.model.getParamList():
578                    temp.append(item)
579                    self.param_list.append(item)
580                else:
581                   
582                    msg = "wrong parameter %s used" % str(item)
583                    msg += "to set model %s. Choose" % str(new_model.model.name)
584                    msg += "parameter name within %s" % \
585                                str(new_model.model.getParamList())
586                    raise ValueError, msg
587             
588            #A fitArrange is already created but contains data_list only at id
589            if self.fit_arrange_dict.has_key(id):
590                self.fit_arrange_dict[id].set_model(new_model)
591                self.fit_arrange_dict[id].pars = pars
592            else:
593            #no fitArrange object has been create with this id
594                fitproblem = FitArrange()
595                fitproblem.set_model(new_model)
596                fitproblem.pars = pars
597                self.fit_arrange_dict[id] = fitproblem
598               
599        else:
600            raise ValueError, "park_integration:missing parameters"
601   
602    def set_data(self, data, id, smearer=None, qmin=None, qmax=None):
603        """
604        Receives plottable, creates a list of data to fit,set data
605        in a FitArrange object and adds that object in a dictionary
606        with key id.
607       
608        :param data: data added
609        :param id: unique key corresponding to a fitArrange object with data
610       
611        """
612        if data.__class__.__name__ == 'Data2D':
613            fitdata = FitData2D(sans_data2d=data, data=data.data,
614                                 err_data=data.err_data)
615        else:
616            fitdata = FitData1D(x=data.x, y=data.y ,
617                                 dx=data.dx, dy=data.dy, smearer=smearer)
618       
619        fitdata.set_fit_range(qmin=qmin, qmax=qmax)
620        #A fitArrange is already created but contains model only at id
621        if self.fit_arrange_dict.has_key(id):
622            self.fit_arrange_dict[id].add_data(fitdata)
623        else:
624        #no fitArrange object has been create with this id
625            fitproblem = FitArrange()
626            fitproblem.add_data(fitdata)
627            self.fit_arrange_dict[id] = fitproblem   
628   
629    def get_model(self, id):
630        """
631       
632        :param id: id is key in the dictionary containing the model to return
633       
634        :return:  a model at this id or None if no FitArrange element was
635            created with this id
636           
637        """
638        if self.fit_arrange_dict.has_key(id):
639            return self.fit_arrange_dict[id].get_model()
640        else:
641            return None
642   
643    def remove_fit_problem(self, id):
644        """remove   fitarrange in id"""
645        if self.fit_arrange_dict.has_key(id):
646            del self.fit_arrange_dict[id]
647           
648    def select_problem_for_fit(self, id, value):
649        """
650        select a couple of model and data at the id position in dictionary
651        and set in self.selected value to value
652       
653        :param value: the value to allow fitting.
654                can only have the value one or zero
655               
656        """
657        if self.fit_arrange_dict.has_key(id):
658            self.fit_arrange_dict[id].set_to_fit(value)
659             
660    def get_problem_to_fit(self, id):
661        """
662        return the self.selected value of the fit problem of id
663       
664        :param id: the id of the problem
665       
666        """
667        if self.fit_arrange_dict.has_key(id):
668            self.fit_arrange_dict[id].get_to_fit()
669   
670class FitArrange:
671    def __init__(self):
672        """
673        Class FitArrange contains a set of data for a given model
674        to perform the Fit.FitArrange must contain exactly one model
675        and at least one data for the fit to be performed.
676       
677        model: the model selected by the user
678        Ldata: a list of data what the user wants to fit
679           
680        """
681        self.model = None
682        self.data_list = []
683        self.pars = []
684        #self.selected  is zero when this fit problem is not schedule to fit
685        #self.selected is 1 when schedule to fit
686        self.selected = 0
687       
688    def set_model(self, model):
689        """
690        set_model save a copy of the model
691       
692        :param model: the model being set
693       
694        """
695        self.model = model
696       
697    def add_data(self, data):
698        """
699        add_data fill a self.data_list with data to fit
700       
701        :param data: Data to add in the list 
702       
703        """
704        if not data in self.data_list:
705            self.data_list.append(data)
706           
707    def get_model(self):
708        """
709       
710        :return: saved model
711       
712        """
713        return self.model   
714     
715    def get_data(self):
716        """
717       
718        :return: list of data data_list
719       
720        """
721        #return self.data_list
722        return self.data_list[0] 
723     
724    def remove_data(self, data):
725        """
726        Remove one element from the list
727       
728        :param data: Data to remove from data_list
729       
730        """
731        if data in self.data_list:
732            self.data_list.remove(data)
733           
734    def set_to_fit (self, value=0):
735        """
736        set self.selected to 0 or 1  for other values raise an exception
737       
738        :param value: integer between 0 or 1
739       
740        """
741        self.selected = value
742       
743    def get_to_fit(self):
744        """
745        return self.selected value
746        """
747        return self.selected
Note: See TracBrowser for help on using the repository browser.