source: sasview/src/sas/sascalc/pr/fit/AbstractFitEngine.py @ 4c930a2

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalc
Last change on this file since 4c930a2 was fc18690, checked in by Piotr Rozyczko <piotr.rozyczko@…>, 9 years ago

Sasmodels integration - moved smearing from models to sascalc.

  • Property mode set to 100644
File size: 20.2 KB
Line 
1
2import  copy
3#import logging
4import sys
5import math
6import numpy
7
8from sas.sascalc.dataloader.data_info import Data1D
9from sas.sascalc.dataloader.data_info import Data2D
10_SMALLVALUE = 1.0e-10
11
12class FitHandler(object):
13    """
14    Abstract interface for fit thread handler.
15
16    The methods in this class are called by the optimizer as the fit
17    progresses.
18
19    Note that it is up to the optimizer to call the fit handler correctly,
20    reporting all status changes and maintaining the 'done' flag.
21    """
22    done = False
23    """True when the fit job is complete"""
24    result = None
25    """The current best result of the fit"""
26
27    def improvement(self):
28        """
29        Called when a result is observed which is better than previous
30        results from the fit.
31
32        result is a FitResult object, with parameters, #calls and fitness.
33        """
34    def error(self, msg):
35        """
36        Model had an error; print traceback
37        """
38    def progress(self, current, expected):
39        """
40        Called each cycle of the fit, reporting the current and the
41        expected amount of work.   The meaning of these values is
42        optimizer dependent, but they can be converted into a percent
43        complete using (100*current)//expected.
44
45        Progress is updated each iteration of the fit, whatever that
46        means for the particular optimization algorithm.  It is called
47        after any calls to improvement for the iteration so that the
48        update handler can control I/O bandwidth by suppressing
49        intermediate improvements until the fit is complete.
50        """
51    def finalize(self):
52        """
53        Fit is complete; best results are reported
54        """
55    def abort(self):
56        """
57        Fit was aborted.
58        """
59
60    # TODO: not sure how these are used, but they are needed for running the fit
61    def update_fit(self, last=False): pass
62    def set_result(self, result=None): self.result = result
63
64class Model:
65    """
66    Fit wrapper for SAS models.
67    """
68    def __init__(self, sas_model, sas_data=None, **kw):
69        """
70        :param sas_model: the sas model to wrap for fitting
71
72        """
73        self.model = sas_model
74        self.name = sas_model.name
75        self.data = sas_data
76
77    def get_params(self, fitparams):
78        """
79        return a list of value of paramter to fit
80
81        :param fitparams: list of paramaters name to fit
82
83        """
84        return [self.model.getParam(k) for k in fitparams]
85
86    def set_params(self, paramlist, params):
87        """
88        Set value for parameters to fit
89
90        :param params: list of value for parameters to fit
91
92        """
93        for k,v in zip(paramlist, params):
94            self.model.setParam(k,v)
95
96    def set(self, **kw):
97        self.set_params(*zip(*kw.items()))
98
99    def eval(self, x):
100        """
101            Override eval method of model.
102
103            :param x: the x value used to compute a function
104        """
105        try:
106            return self.model.evalDistribution(x)
107        except:
108            raise
109
110    def eval_derivs(self, x, pars=[]):
111        """
112        Evaluate the model and derivatives wrt pars at x.
113
114        pars is a list of the names of the parameters for which derivatives
115        are desired.
116
117        This method needs to be specialized in the model to evaluate the
118        model function.  Alternatively, the model can implement is own
119        version of residuals which calculates the residuals directly
120        instead of calling eval.
121        """
122        raise NotImplementedError('no derivatives available')
123
124    def __call__(self, x):
125        return self.eval(x)
126
127class FitData1D(Data1D):
128    """
129        Wrapper class  for SAS data
130        FitData1D inherits from DataLoader.data_info.Data1D. Implements
131        a way to get residuals from data.
132    """
133    def __init__(self, x, y, dx=None, dy=None, smearer=None, data=None):
134        """
135            :param smearer: is an object of class QSmearer or SlitSmearer
136               that will smear the theory data (slit smearing or resolution
137               smearing) when set.
138           
139            The proper way to set the smearing object would be to
140            do the following: ::
141           
142                from sas.sascalc.data_util.qsmearing import smear_selection
143                smearer = smear_selection(some_data)
144                fitdata1d = FitData1D( x= [1,3,..,],
145                                        y= [3,4,..,8],
146                                        dx=None,
147                                        dy=[1,2...], smearer= smearer)
148           
149            :Note: that some_data _HAS_ to be of
150                class DataLoader.data_info.Data1D
151                Setting it back to None will turn smearing off.
152               
153        """
154        Data1D.__init__(self, x=x, y=y, dx=dx, dy=dy)
155        self.num_points = len(x)
156        self.sas_data = data
157        self.smearer = smearer
158        self._first_unsmeared_bin = None
159        self._last_unsmeared_bin = None
160        # Check error bar; if no error bar found, set it constant(=1)
161        # TODO: Should provide an option for users to set it like percent,
162        # constant, or dy data
163        if dy is None or dy == [] or dy.all() == 0:
164            self.dy = numpy.ones(len(y))
165        else:
166            self.dy = numpy.asarray(dy).copy()
167
168        ## Min Q-value
169        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
170        if min(self.x) == 0.0 and self.x[0] == 0 and\
171                     not numpy.isfinite(self.y[0]):
172            self.qmin = min(self.x[self.x != 0])
173        else:
174            self.qmin = min(self.x)
175        ## Max Q-value
176        self.qmax = max(self.x)
177       
178        # Range used for input to smearing
179        self._qmin_unsmeared = self.qmin
180        self._qmax_unsmeared = self.qmax
181        # Identify the bin range for the unsmeared and smeared spaces
182        self.idx = (self.x >= self.qmin) & (self.x <= self.qmax)
183        self.idx_unsmeared = (self.x >= self._qmin_unsmeared) \
184                            & (self.x <= self._qmax_unsmeared)
185 
186    def set_fit_range(self, qmin=None, qmax=None):
187        """ to set the fit range"""
188        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
189        # ToDo: Find better way to do it.
190        if qmin == 0.0 and not numpy.isfinite(self.y[qmin]):
191            self.qmin = min(self.x[self.x != 0])
192        elif qmin != None:
193            self.qmin = qmin
194        if qmax != None:
195            self.qmax = qmax
196        # Determine the range needed in unsmeared-Q to cover
197        # the smeared Q range
198        self._qmin_unsmeared = self.qmin
199        self._qmax_unsmeared = self.qmax
200       
201        self._first_unsmeared_bin = 0
202        self._last_unsmeared_bin = len(self.x) - 1
203       
204        if self.smearer != None:
205            self._first_unsmeared_bin, self._last_unsmeared_bin = \
206                    self.smearer.get_bin_range(self.qmin, self.qmax)
207            self._qmin_unsmeared = self.x[self._first_unsmeared_bin]
208            self._qmax_unsmeared = self.x[self._last_unsmeared_bin]
209           
210        # Identify the bin range for the unsmeared and smeared spaces
211        self.idx = (self.x >= self.qmin) & (self.x <= self.qmax)
212        ## zero error can not participate for fitting
213        self.idx = self.idx & (self.dy != 0)
214        self.idx_unsmeared = (self.x >= self._qmin_unsmeared) \
215                            & (self.x <= self._qmax_unsmeared)
216
217    def get_fit_range(self):
218        """
219            Return the range of data.x to fit
220        """
221        return self.qmin, self.qmax
222
223    def size(self):
224        """
225        Number of measurement points in data set after masking, etc.
226        """
227        return len(self.x)
228
229    def residuals(self, fn):
230        """
231            Compute residuals.
232           
233            If self.smearer has been set, use if to smear
234            the data before computing chi squared.
235           
236            :param fn: function that return model value
237           
238            :return: residuals
239        """
240        # Compute theory data f(x)
241        fx = numpy.zeros(len(self.x))
242        fx[self.idx_unsmeared] = fn(self.x[self.idx_unsmeared])
243       
244        ## Smear theory data
245        if self.smearer is not None:
246            fx = self.smearer(fx, self._first_unsmeared_bin,
247                              self._last_unsmeared_bin)
248        ## Sanity check
249        if numpy.size(self.dy) != numpy.size(fx):
250            msg = "FitData1D: invalid error array "
251            msg += "%d <> %d" % (numpy.shape(self.dy), numpy.size(fx))
252            raise RuntimeError, msg
253        return (self.y[self.idx] - fx[self.idx]) / self.dy[self.idx], fx[self.idx]
254           
255    def residuals_deriv(self, model, pars=[]):
256        """
257            :return: residuals derivatives .
258           
259            :note: in this case just return empty array
260        """
261        return []
262
263
264class FitData2D(Data2D):
265    """
266        Wrapper class  for SAS data
267    """
268    def __init__(self, sas_data2d, data=None, err_data=None):
269        Data2D.__init__(self, data=data, err_data=err_data)
270        # Data can be initialized with a sas plottable or with vectors.
271        self.res_err_image = []
272        self.num_points = 0 # will be set by set_data
273        self.idx = []
274        self.qmin = None
275        self.qmax = None
276        self.smearer = None
277        self.radius = 0
278        self.res_err_data = []
279        self.sas_data = sas_data2d
280        self.set_data(sas_data2d)
281
282    def set_data(self, sas_data2d, qmin=None, qmax=None):
283        """
284            Determine the correct qx_data and qy_data within range to fit
285        """
286        self.data = sas_data2d.data
287        self.err_data = sas_data2d.err_data
288        self.qx_data = sas_data2d.qx_data
289        self.qy_data = sas_data2d.qy_data
290        self.mask = sas_data2d.mask
291
292        x_max = max(math.fabs(sas_data2d.xmin), math.fabs(sas_data2d.xmax))
293        y_max = max(math.fabs(sas_data2d.ymin), math.fabs(sas_data2d.ymax))
294       
295        ## fitting range
296        if qmin == None:
297            self.qmin = 1e-16
298        if qmax == None:
299            self.qmax = math.sqrt(x_max * x_max + y_max * y_max)
300        ## new error image for fitting purpose
301        if self.err_data == None or self.err_data == []:
302            self.res_err_data = numpy.ones(len(self.data))
303        else:
304            self.res_err_data = copy.deepcopy(self.err_data)
305        #self.res_err_data[self.res_err_data==0]=1
306       
307        self.radius = numpy.sqrt(self.qx_data**2 + self.qy_data**2)
308       
309        # Note: mask = True: for MASK while mask = False for NOT to mask
310        self.idx = ((self.qmin <= self.radius) &\
311                            (self.radius <= self.qmax))
312        self.idx = (self.idx) & (self.mask)
313        self.idx = (self.idx) & (numpy.isfinite(self.data))
314        self.num_points = numpy.sum(self.idx)
315
316    def set_smearer(self, smearer):
317        """
318            Set smearer
319        """
320        if smearer == None:
321            return
322        self.smearer = smearer
323        self.smearer.set_index(self.idx)
324        self.smearer.get_data()
325
326    def set_fit_range(self, qmin=None, qmax=None):
327        """
328            To set the fit range
329        """
330        if qmin == 0.0:
331            self.qmin = 1e-16
332        elif qmin != None:
333            self.qmin = qmin
334        if qmax != None:
335            self.qmax = qmax
336        self.radius = numpy.sqrt(self.qx_data**2 + self.qy_data**2)
337        self.idx = ((self.qmin <= self.radius) &\
338                            (self.radius <= self.qmax))
339        self.idx = (self.idx) & (self.mask)
340        self.idx = (self.idx) & (numpy.isfinite(self.data))
341        self.idx = (self.idx) & (self.res_err_data != 0)
342
343    def get_fit_range(self):
344        """
345        return the range of data.x to fit
346        """
347        return self.qmin, self.qmax
348
349    def size(self):
350        """
351        Number of measurement points in data set after masking, etc.
352        """
353        return numpy.sum(self.idx)
354
355    def residuals(self, fn):
356        """
357        return the residuals
358        """
359        if self.smearer != None:
360            fn.set_index(self.idx)
361            # Get necessary data from self.data and set the data for smearing
362            fn.get_data()
363
364            gn = fn.get_value()
365        else:
366            gn = fn([self.qx_data[self.idx],
367                     self.qy_data[self.idx]])
368        # use only the data point within ROI range
369        res = (self.data[self.idx] - gn) / self.res_err_data[self.idx]
370
371        return res, gn
372       
373    def residuals_deriv(self, model, pars=[]):
374        """
375        :return: residuals derivatives .
376       
377        :note: in this case just return empty array
378       
379        """
380        return []
381   
382   
383class FitAbort(Exception):
384    """
385    Exception raise to stop the fit
386    """
387    #pass
388    #print"Creating fit abort Exception"
389
390
391
392class FitEngine:
393    def __init__(self):
394        """
395        Base class for the fit engine
396        """
397        #Dictionnary of fitArrange element (fit problems)
398        self.fit_arrange_dict = {}
399        self.fitter_id = None
400       
401    def set_model(self, model, id, pars=[], constraints=[], data=None):
402        """
403        set a model on a given  in the fit engine.
404       
405        :param model: sas.models type
406        :param id: is the key of the fitArrange dictionary where model is saved as a value
407        :param pars: the list of parameters to fit
408        :param constraints: list of
409            tuple (name of parameter, value of parameters)
410            the value of parameter must be a string to constraint 2 different
411            parameters.
412            Example: 
413            we want to fit 2 model M1 and M2 both have parameters A and B.
414            constraints can be ``constraints = [(M1.A, M2.B+2), (M1.B= M2.A *5),...,]``
415           
416             
417        :note: pars must contains only name of existing model's parameters
418       
419        """
420        if not pars:
421            raise ValueError("no fitting parameters")
422
423        if model is None:
424            raise ValueError("no model to fit")
425
426        if not issubclass(model.__class__, Model):
427            model = Model(model, data)
428
429        sasmodel = model.model
430        available_parameters = sasmodel.getParamList()
431        for p in pars:
432            if p not in available_parameters:
433                raise ValueError("parameter %s not available in model %s; use one of [%s] instead"
434                                 %(p, sasmodel.name, ", ".join(available_parameters)))
435
436        if id not in self.fit_arrange_dict:
437            self.fit_arrange_dict[id] = FitArrange()
438
439        self.fit_arrange_dict[id].set_model(model)
440        self.fit_arrange_dict[id].pars = pars
441        self.fit_arrange_dict[id].vals = [sasmodel.getParam(name) for name in pars]
442        self.fit_arrange_dict[id].constraints = constraints
443
444    def set_data(self, data, id, smearer=None, qmin=None, qmax=None):
445        """
446        Receives plottable, creates a list of data to fit,set data
447        in a FitArrange object and adds that object in a dictionary
448        with key id.
449       
450        :param data: data added
451        :param id: unique key corresponding to a fitArrange object with data
452        """
453        if data.__class__.__name__ == 'Data2D':
454            fitdata = FitData2D(sas_data2d=data, data=data.data,
455                                 err_data=data.err_data)
456        else:
457            fitdata = FitData1D(x=data.x, y=data.y,
458                                 dx=data.dx, dy=data.dy, smearer=smearer)
459        fitdata.sas_data = data
460       
461        fitdata.set_fit_range(qmin=qmin, qmax=qmax)
462        #A fitArrange is already created but contains model only at id
463        if id in self.fit_arrange_dict:
464            self.fit_arrange_dict[id].add_data(fitdata)
465        else:
466        #no fitArrange object has been create with this id
467            fitproblem = FitArrange()
468            fitproblem.add_data(fitdata)
469            self.fit_arrange_dict[id] = fitproblem
470   
471    def get_model(self, id):
472        """
473        :param id: id is key in the dictionary containing the model to return
474       
475        :return:  a model at this id or None if no FitArrange element was
476            created with this id
477        """
478        if id in self.fit_arrange_dict:
479            return self.fit_arrange_dict[id].get_model()
480        else:
481            return None
482   
483    def remove_fit_problem(self, id):
484        """remove   fitarrange in id"""
485        if id in self.fit_arrange_dict:
486            del self.fit_arrange_dict[id]
487           
488    def select_problem_for_fit(self, id, value):
489        """
490        select a couple of model and data at the id position in dictionary
491        and set in self.selected value to value
492       
493        :param value: the value to allow fitting.
494                can only have the value one or zero
495        """
496        if id in self.fit_arrange_dict:
497            self.fit_arrange_dict[id].set_to_fit(value)
498             
499    def get_problem_to_fit(self, id):
500        """
501        return the self.selected value of the fit problem of id
502       
503        :param id: the id of the problem
504        """
505        if id in self.fit_arrange_dict:
506            self.fit_arrange_dict[id].get_to_fit()
507   
508   
509class FitArrange:
510    def __init__(self):
511        """
512        Class FitArrange contains a set of data for a given model
513        to perform the Fit.FitArrange must contain exactly one model
514        and at least one data for the fit to be performed.
515       
516        model: the model selected by the user
517        Ldata: a list of data what the user wants to fit
518           
519        """
520        self.model = None
521        self.data_list = []
522        self.pars = []
523        self.vals = []
524        self.selected = 0
525
526    def set_model(self, model):
527        """
528        set_model save a copy of the model
529       
530        :param model: the model being set
531        """
532        self.model = model
533       
534    def add_data(self, data):
535        """
536        add_data fill a self.data_list with data to fit
537       
538        :param data: Data to add in the list
539        """
540        if not data in self.data_list:
541            self.data_list.append(data)
542           
543    def get_model(self):
544        """
545        :return: saved model
546        """
547        return self.model
548     
549    def get_data(self):
550        """
551        :return: list of data data_list
552        """
553        return self.data_list[0]
554     
555    def remove_data(self, data):
556        """
557        Remove one element from the list
558       
559        :param data: Data to remove from data_list
560        """
561        if data in self.data_list:
562            self.data_list.remove(data)
563           
564    def set_to_fit(self, value=0):
565        """
566        set self.selected to 0 or 1  for other values raise an exception
567       
568        :param value: integer between 0 or 1
569        """
570        self.selected = value
571       
572    def get_to_fit(self):
573        """
574        return self.selected value
575        """
576        return self.selected
577
578class FResult(object):
579    """
580    Storing fit result
581    """
582    def __init__(self, model=None, param_list=None, data=None):
583        self.calls = None
584        self.fitness = None
585        self.chisqr = None
586        self.pvec = []
587        self.cov = []
588        self.info = None
589        self.mesg = None
590        self.success = None
591        self.stderr = None
592        self.residuals = []
593        self.index = []
594        self.model = model
595        self.data = data
596        self.theory = []
597        self.param_list = param_list
598        self.iterations = 0
599        self.inputs = []
600        self.fitter_id = None
601        if self.model is not None and self.data is not None:
602            self.inputs = [(self.model, self.data)]
603     
604    def set_model(self, model):
605        """
606        """
607        self.model = model
608       
609    def set_fitness(self, fitness):
610        """
611        """
612        self.fitness = fitness
613       
614    def __str__(self):
615        """
616        """
617        if self.pvec == None and self.model is None and self.param_list is None:
618            return "No results"
619
620        sasmodel = self.model.model
621        pars = enumerate(sasmodel.getParamList())
622        msg1 = "[Iteration #: %s ]" % self.iterations
623        msg3 = "=== goodness of fit: %s ===" % (str(self.fitness))
624        msg2 = ["P%-3d  %s......|.....%s" % (i, v, sasmodel.getParam(v))
625                for i,v in pars if v in self.param_list]
626        msg = [msg1, msg3] + msg2
627        return "\n".join(msg)
628   
629    def print_summary(self):
630        """
631        """
632        print str(self)
Note: See TracBrowser for help on using the repository browser.