source: sasview/src/sans/fit/AbstractFitEngine.py @ 5bf0331

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 5bf0331 was 95d58d3, checked in by pkienzle, 10 years ago

fix fit line test for bumps/scipy/park and enable it as part of test suite

  • Property mode set to 100644
File size: 21.4 KB
RevLine 
[51f14603]1
2import  copy
3#import logging
4import sys
5import math
[6fe5100]6import numpy
7
[51f14603]8from sans.dataloader.data_info import Data1D
9from sans.dataloader.data_info import Data2D
[6fe5100]10_SMALLVALUE = 1.0e-10
11
12# Note: duplicated from park
13class FitHandler(object):
[51f14603]14    """
[6fe5100]15    Abstract interface for fit thread handler.
16
17    The methods in this class are called by the optimizer as the fit
18    progresses.
19
20    Note that it is up to the optimizer to call the fit handler correctly,
21    reporting all status changes and maintaining the 'done' flag.
[51f14603]22    """
[6fe5100]23    done = False
24    """True when the fit job is complete"""
25    result = None
26    """The current best result of the fit"""
27
28    def improvement(self):
[51f14603]29        """
[6fe5100]30        Called when a result is observed which is better than previous
31        results from the fit.
32
33        result is a FitResult object, with parameters, #calls and fitness.
[51f14603]34        """
[6fe5100]35    def error(self, msg):
[51f14603]36        """
[6fe5100]37        Model had an error; print traceback
[51f14603]38        """
[6fe5100]39    def progress(self, current, expected):
[51f14603]40        """
[6fe5100]41        Called each cycle of the fit, reporting the current and the
42        expected amount of work.   The meaning of these values is
43        optimizer dependent, but they can be converted into a percent
44        complete using (100*current)//expected.
45
46        Progress is updated each iteration of the fit, whatever that
47        means for the particular optimization algorithm.  It is called
48        after any calls to improvement for the iteration so that the
49        update handler can control I/O bandwidth by suppressing
50        intermediate improvements until the fit is complete.
[51f14603]51        """
[6fe5100]52    def finalize(self):
[51f14603]53        """
[6fe5100]54        Fit is complete; best results are reported
[51f14603]55        """
[6fe5100]56    def abort(self):
[51f14603]57        """
[6fe5100]58        Fit was aborted.
[51f14603]59        """
[6fe5100]60
[95d58d3]61    # TODO: not sure how these are used, but they are needed for running the fit
62    def update_fit(self, last=False): pass
63    def set_result(self, result=None): self.result = result
64
[6fe5100]65class Model:
[51f14603]66    """
[6fe5100]67    Fit wrapper for SANS models.
[51f14603]68    """
69    def __init__(self, sans_model, sans_data=None, **kw):
70        """
71        :param sans_model: the sans model to wrap using park interface
[6fe5100]72
[51f14603]73        """
74        self.model = sans_model
75        self.name = sans_model.name
76        self.data = sans_data
[6fe5100]77
[51f14603]78    def get_params(self, fitparams):
79        """
80        return a list of value of paramter to fit
[6fe5100]81
[51f14603]82        :param fitparams: list of paramaters name to fit
[6fe5100]83
[51f14603]84        """
[6fe5100]85        return [self.model.getParam(k) for k in fitparams]
86
[51f14603]87    def set_params(self, paramlist, params):
88        """
89        Set value for parameters to fit
[6fe5100]90
[51f14603]91        :param params: list of value for parameters to fit
[6fe5100]92
[51f14603]93        """
[6fe5100]94        for k,v in zip(paramlist, params):
95            self.model.setParam(k,v)
96
97    def set(self, **kw):
98        self.set_params(*zip(*kw.items()))
99
[51f14603]100    def eval(self, x):
101        """
102            Override eval method of park model.
[6fe5100]103
[51f14603]104            :param x: the x value used to compute a function
105        """
106        try:
107            return self.model.evalDistribution(x)
108        except:
109            raise
[6fe5100]110
[51f14603]111    def eval_derivs(self, x, pars=[]):
112        """
113        Evaluate the model and derivatives wrt pars at x.
114
115        pars is a list of the names of the parameters for which derivatives
116        are desired.
117
118        This method needs to be specialized in the model to evaluate the
119        model function.  Alternatively, the model can implement is own
120        version of residuals which calculates the residuals directly
121        instead of calling eval.
122        """
[6fe5100]123        raise NotImplementedError('no derivatives available')
124
125    def __call__(self, x):
126        return self.eval(x)
[51f14603]127
128class FitData1D(Data1D):
129    """
130        Wrapper class  for SANS data
131        FitData1D inherits from DataLoader.data_info.Data1D. Implements
132        a way to get residuals from data.
133    """
134    def __init__(self, x, y, dx=None, dy=None, smearer=None, data=None):
135        """
136            :param smearer: is an object of class QSmearer or SlitSmearer
137               that will smear the theory data (slit smearing or resolution
138               smearing) when set.
139           
140            The proper way to set the smearing object would be to
141            do the following: ::
142           
[6c00702]143                from sans.models.qsmearing import smear_selection
[51f14603]144                smearer = smear_selection(some_data)
145                fitdata1d = FitData1D( x= [1,3,..,],
146                                        y= [3,4,..,8],
147                                        dx=None,
148                                        dy=[1,2...], smearer= smearer)
149           
150            :Note: that some_data _HAS_ to be of
151                class DataLoader.data_info.Data1D
152                Setting it back to None will turn smearing off.
153               
154        """
155        Data1D.__init__(self, x=x, y=y, dx=dx, dy=dy)
[6fe5100]156        self.num_points = len(x)
[51f14603]157        self.sans_data = data
158        self.smearer = smearer
159        self._first_unsmeared_bin = None
160        self._last_unsmeared_bin = None
161        # Check error bar; if no error bar found, set it constant(=1)
162        # TODO: Should provide an option for users to set it like percent,
163        # constant, or dy data
164        if dy == None or dy == [] or dy.all() == 0:
165            self.dy = numpy.ones(len(y))
166        else:
167            self.dy = numpy.asarray(dy).copy()
168
169        ## Min Q-value
170        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
171        if min(self.x) == 0.0 and self.x[0] == 0 and\
172                     not numpy.isfinite(self.y[0]):
173            self.qmin = min(self.x[self.x != 0])
174        else:
175            self.qmin = min(self.x)
176        ## Max Q-value
177        self.qmax = max(self.x)
178       
179        # Range used for input to smearing
180        self._qmin_unsmeared = self.qmin
181        self._qmax_unsmeared = self.qmax
182        # Identify the bin range for the unsmeared and smeared spaces
183        self.idx = (self.x >= self.qmin) & (self.x <= self.qmax)
184        self.idx_unsmeared = (self.x >= self._qmin_unsmeared) \
185                            & (self.x <= self._qmax_unsmeared)
186 
187    def set_fit_range(self, qmin=None, qmax=None):
188        """ to set the fit range"""
189        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
190        # ToDo: Find better way to do it.
191        if qmin == 0.0 and not numpy.isfinite(self.y[qmin]):
192            self.qmin = min(self.x[self.x != 0])
193        elif qmin != None:
194            self.qmin = qmin
195        if qmax != None:
196            self.qmax = qmax
197        # Determine the range needed in unsmeared-Q to cover
198        # the smeared Q range
199        self._qmin_unsmeared = self.qmin
200        self._qmax_unsmeared = self.qmax
201       
202        self._first_unsmeared_bin = 0
203        self._last_unsmeared_bin = len(self.x) - 1
204       
205        if self.smearer != None:
206            self._first_unsmeared_bin, self._last_unsmeared_bin = \
207                    self.smearer.get_bin_range(self.qmin, self.qmax)
208            self._qmin_unsmeared = self.x[self._first_unsmeared_bin]
209            self._qmax_unsmeared = self.x[self._last_unsmeared_bin]
210           
211        # Identify the bin range for the unsmeared and smeared spaces
212        self.idx = (self.x >= self.qmin) & (self.x <= self.qmax)
213        ## zero error can not participate for fitting
214        self.idx = self.idx & (self.dy != 0)
215        self.idx_unsmeared = (self.x >= self._qmin_unsmeared) \
216                            & (self.x <= self._qmax_unsmeared)
217
218    def get_fit_range(self):
219        """
220            Return the range of data.x to fit
221        """
222        return self.qmin, self.qmax
[95d58d3]223
224    def size(self):
225        """
226        Number of measurement points in data set after masking, etc.
227        """
228        return len(self.x)
229
[51f14603]230    def residuals(self, fn):
231        """
232            Compute residuals.
233           
234            If self.smearer has been set, use if to smear
235            the data before computing chi squared.
236           
237            :param fn: function that return model value
238           
239            :return: residuals
240        """
241        # Compute theory data f(x)
242        fx = numpy.zeros(len(self.x))
243        fx[self.idx_unsmeared] = fn(self.x[self.idx_unsmeared])
244       
245        ## Smear theory data
246        if self.smearer is not None:
247            fx = self.smearer(fx, self._first_unsmeared_bin,
248                              self._last_unsmeared_bin)
249        ## Sanity check
250        if numpy.size(self.dy) != numpy.size(fx):
251            msg = "FitData1D: invalid error array "
252            msg += "%d <> %d" % (numpy.shape(self.dy), numpy.size(fx))
253            raise RuntimeError, msg
254        return (self.y[self.idx] - fx[self.idx]) / self.dy[self.idx], fx[self.idx]
255           
256    def residuals_deriv(self, model, pars=[]):
257        """
258            :return: residuals derivatives .
259           
260            :note: in this case just return empty array
261        """
262        return []
263   
264   
265class FitData2D(Data2D):
266    """
267        Wrapper class  for SANS data
268    """
269    def __init__(self, sans_data2d, data=None, err_data=None):
270        Data2D.__init__(self, data=data, err_data=err_data)
[95d58d3]271        # Data can be initialized with a sans plottable or with vectors.
[51f14603]272        self.res_err_image = []
[95d58d3]273        self.num_points = 0 # will be set by set_data
[51f14603]274        self.idx = []
275        self.qmin = None
276        self.qmax = None
277        self.smearer = None
278        self.radius = 0
279        self.res_err_data = []
280        self.sans_data = sans_data2d
281        self.set_data(sans_data2d)
282
283    def set_data(self, sans_data2d, qmin=None, qmax=None):
284        """
285            Determine the correct qx_data and qy_data within range to fit
286        """
287        self.data = sans_data2d.data
288        self.err_data = sans_data2d.err_data
289        self.qx_data = sans_data2d.qx_data
290        self.qy_data = sans_data2d.qy_data
291        self.mask = sans_data2d.mask
292
293        x_max = max(math.fabs(sans_data2d.xmin), math.fabs(sans_data2d.xmax))
294        y_max = max(math.fabs(sans_data2d.ymin), math.fabs(sans_data2d.ymax))
295       
296        ## fitting range
297        if qmin == None:
298            self.qmin = 1e-16
299        if qmax == None:
300            self.qmax = math.sqrt(x_max * x_max + y_max * y_max)
301        ## new error image for fitting purpose
302        if self.err_data == None or self.err_data == []:
303            self.res_err_data = numpy.ones(len(self.data))
304        else:
305            self.res_err_data = copy.deepcopy(self.err_data)
306        #self.res_err_data[self.res_err_data==0]=1
307       
308        self.radius = numpy.sqrt(self.qx_data**2 + self.qy_data**2)
309       
310        # Note: mask = True: for MASK while mask = False for NOT to mask
311        self.idx = ((self.qmin <= self.radius) &\
312                            (self.radius <= self.qmax))
313        self.idx = (self.idx) & (self.mask)
314        self.idx = (self.idx) & (numpy.isfinite(self.data))
[95d58d3]315        self.num_points = numpy.sum(self.idx)
[51f14603]316
317    def set_smearer(self, smearer):
318        """
319            Set smearer
320        """
321        if smearer == None:
322            return
323        self.smearer = smearer
324        self.smearer.set_index(self.idx)
325        self.smearer.get_data()
326
327    def set_fit_range(self, qmin=None, qmax=None):
328        """
329            To set the fit range
330        """
331        if qmin == 0.0:
332            self.qmin = 1e-16
333        elif qmin != None:
334            self.qmin = qmin
335        if qmax != None:
336            self.qmax = qmax
337        self.radius = numpy.sqrt(self.qx_data**2 + self.qy_data**2)
338        self.idx = ((self.qmin <= self.radius) &\
339                            (self.radius <= self.qmax))
340        self.idx = (self.idx) & (self.mask)
341        self.idx = (self.idx) & (numpy.isfinite(self.data))
342        self.idx = (self.idx) & (self.res_err_data != 0)
343
344    def get_fit_range(self):
345        """
346        return the range of data.x to fit
347        """
348        return self.qmin, self.qmax
[95d58d3]349
350    def size(self):
351        """
352        Number of measurement points in data set after masking, etc.
353        """
354        return numpy.sum(self.idx)
355
[51f14603]356    def residuals(self, fn):
357        """
358        return the residuals
359        """
360        if self.smearer != None:
361            fn.set_index(self.idx)
362            # Get necessary data from self.data and set the data for smearing
363            fn.get_data()
364
365            gn = fn.get_value()
366        else:
367            gn = fn([self.qx_data[self.idx],
368                     self.qy_data[self.idx]])
369        # use only the data point within ROI range
370        res = (self.data[self.idx] - gn) / self.res_err_data[self.idx]
371
372        return res, gn
373       
374    def residuals_deriv(self, model, pars=[]):
375        """
376        :return: residuals derivatives .
377       
378        :note: in this case just return empty array
379       
380        """
381        return []
382   
383   
384class FitAbort(Exception):
385    """
386    Exception raise to stop the fit
387    """
388    #pass
389    #print"Creating fit abort Exception"
390
391
392
393class FitEngine:
394    def __init__(self):
395        """
396        Base class for scipy and park fit engine
397        """
398        #List of parameter names to fit
399        self.param_list = []
400        #Dictionnary of fitArrange element (fit problems)
401        self.fit_arrange_dict = {}
402        self.fitter_id = None
403       
404    def set_model(self, model, id, pars=[], constraints=[], data=None):
405        """
406        set a model on a given  in the fit engine.
407       
408        :param model: sans.models type
409        :param id: is the key of the fitArrange dictionary where model is saved as a value
410        :param pars: the list of parameters to fit
411        :param constraints: list of
412            tuple (name of parameter, value of parameters)
413            the value of parameter must be a string to constraint 2 different
414            parameters.
415            Example: 
416            we want to fit 2 model M1 and M2 both have parameters A and B.
417            constraints can be ``constraints = [(M1.A, M2.B+2), (M1.B= M2.A *5),...,]``
418           
419             
420        :note: pars must contains only name of existing model's parameters
421       
422        """
423        if model == None:
424            raise ValueError, "AbstractFitEngine: Need to set model to fit"
425       
426        if not issubclass(model.__class__, Model):
[95d58d3]427            model = Model(model, data)
428
429        sasmodel = model.model
[51f14603]430        if len(constraints) > 0:
431            for constraint in constraints:
432                name, value = constraint
433                try:
[95d58d3]434                    model.parameterset[str(name)].set(str(value))
[51f14603]435                except:
436                    msg = "Fit Engine: Error occurs when setting the constraint"
437                    msg += " %s for parameter %s " % (value, name)
438                    raise ValueError, msg
439               
440        if len(pars) > 0:
441            temp = []
442            for item in pars:
[95d58d3]443                if item in sasmodel.getParamList():
[51f14603]444                    temp.append(item)
445                    self.param_list.append(item)
446                else:
447                   
[f87dc4c]448                    msg = "wrong parameter %s used " % str(item)
[95d58d3]449                    msg += "to set model %s. Choose " % str(sasmodel.name)
[51f14603]450                    msg += "parameter name within %s" % \
[95d58d3]451                                str(sasmodel.getParamList())
[51f14603]452                    raise ValueError, msg
453             
454            #A fitArrange is already created but contains data_list only at id
455            if self.fit_arrange_dict.has_key(id):
[95d58d3]456                self.fit_arrange_dict[id].set_model(model)
[51f14603]457                self.fit_arrange_dict[id].pars = pars
458            else:
459            #no fitArrange object has been create with this id
460                fitproblem = FitArrange()
[95d58d3]461                fitproblem.set_model(model)
[51f14603]462                fitproblem.pars = pars
463                self.fit_arrange_dict[id] = fitproblem
464                vals = []
465                for name in pars:
[95d58d3]466                    vals.append(sasmodel.getParam(name))
[51f14603]467                self.fit_arrange_dict[id].vals = vals
468        else:
469            raise ValueError, "park_integration:missing parameters"
470   
471    def set_data(self, data, id, smearer=None, qmin=None, qmax=None):
472        """
473        Receives plottable, creates a list of data to fit,set data
474        in a FitArrange object and adds that object in a dictionary
475        with key id.
476       
477        :param data: data added
478        :param id: unique key corresponding to a fitArrange object with data
479        """
480        if data.__class__.__name__ == 'Data2D':
481            fitdata = FitData2D(sans_data2d=data, data=data.data,
482                                 err_data=data.err_data)
483        else:
484            fitdata = FitData1D(x=data.x, y=data.y,
485                                 dx=data.dx, dy=data.dy, smearer=smearer)
486        fitdata.sans_data = data
487       
488        fitdata.set_fit_range(qmin=qmin, qmax=qmax)
489        #A fitArrange is already created but contains model only at id
490        if id in self.fit_arrange_dict:
491            self.fit_arrange_dict[id].add_data(fitdata)
492        else:
493        #no fitArrange object has been create with this id
494            fitproblem = FitArrange()
495            fitproblem.add_data(fitdata)
496            self.fit_arrange_dict[id] = fitproblem
497   
498    def get_model(self, id):
499        """
500        :param id: id is key in the dictionary containing the model to return
501       
502        :return:  a model at this id or None if no FitArrange element was
503            created with this id
504        """
505        if id in self.fit_arrange_dict:
506            return self.fit_arrange_dict[id].get_model()
507        else:
508            return None
509   
510    def remove_fit_problem(self, id):
511        """remove   fitarrange in id"""
512        if id in self.fit_arrange_dict:
513            del self.fit_arrange_dict[id]
514           
515    def select_problem_for_fit(self, id, value):
516        """
517        select a couple of model and data at the id position in dictionary
518        and set in self.selected value to value
519       
520        :param value: the value to allow fitting.
521                can only have the value one or zero
522        """
523        if id in self.fit_arrange_dict:
524            self.fit_arrange_dict[id].set_to_fit(value)
525             
526    def get_problem_to_fit(self, id):
527        """
528        return the self.selected value of the fit problem of id
529       
530        :param id: the id of the problem
531        """
532        if id in self.fit_arrange_dict:
533            self.fit_arrange_dict[id].get_to_fit()
534   
535   
536class FitArrange:
537    def __init__(self):
538        """
539        Class FitArrange contains a set of data for a given model
540        to perform the Fit.FitArrange must contain exactly one model
541        and at least one data for the fit to be performed.
542       
543        model: the model selected by the user
544        Ldata: a list of data what the user wants to fit
545           
546        """
547        self.model = None
548        self.data_list = []
549        self.pars = []
550        self.vals = []
551        self.selected = 0
552       
553    def set_model(self, model):
554        """
555        set_model save a copy of the model
556       
557        :param model: the model being set
558        """
559        self.model = model
560       
561    def add_data(self, data):
562        """
563        add_data fill a self.data_list with data to fit
564       
565        :param data: Data to add in the list
566        """
567        if not data in self.data_list:
568            self.data_list.append(data)
569           
570    def get_model(self):
571        """
572        :return: saved model
573        """
574        return self.model
575     
576    def get_data(self):
577        """
578        :return: list of data data_list
579        """
580        return self.data_list[0]
581     
582    def remove_data(self, data):
583        """
584        Remove one element from the list
585       
586        :param data: Data to remove from data_list
587        """
588        if data in self.data_list:
589            self.data_list.remove(data)
590           
591    def set_to_fit(self, value=0):
592        """
593        set self.selected to 0 or 1  for other values raise an exception
594       
595        :param value: integer between 0 or 1
596        """
597        self.selected = value
598       
599    def get_to_fit(self):
600        """
601        return self.selected value
602        """
603        return self.selected
604   
605   
606class FResult(object):
607    """
608    Storing fit result
609    """
610    def __init__(self, model=None, param_list=None, data=None):
611        self.calls = None
612        self.pars = []
613        self.fitness = None
614        self.chisqr = None
615        self.pvec = []
616        self.cov = []
617        self.info = None
618        self.mesg = None
619        self.success = None
620        self.stderr = None
621        self.residuals = []
622        self.index = []
623        self.parameters = None
624        self.model = model
625        self.data = data
626        self.theory = []
627        self.param_list = param_list
628        self.iterations = 0
629        self.inputs = []
630        self.fitter_id = None
631        if self.model is not None and self.data is not None:
632            self.inputs = [(self.model, self.data)]
633     
634    def set_model(self, model):
635        """
636        """
637        self.model = model
638       
639    def set_fitness(self, fitness):
640        """
641        """
642        self.fitness = fitness
643       
644    def __str__(self):
645        """
646        """
647        if self.pvec == None and self.model is None and self.param_list is None:
648            return "No results"
[6fe5100]649
[95d58d3]650        sasmodel = self.model.model
651        pars = enumerate(sasmodel.getParamList())
[6fe5100]652        msg1 = "[Iteration #: %s ]" % self.iterations
653        msg3 = "=== goodness of fit: %s ===" % (str(self.fitness))
[95d58d3]654        msg2 = ["P%-3d  %s......|.....%s" % (i, v, sasmodel.getParam(v))
[6fe5100]655                for i,v in pars if v in self.param_list]
656        msg = [msg1, msg3] + msg2
657        return "\n".join(msg)
[51f14603]658   
659    def print_summary(self):
660        """
661        """
[95d58d3]662        print str(self)
Note: See TracBrowser for help on using the repository browser.