source: sasview/src/sas/sascalc/fit/AbstractFitEngine.py @ d619341

Last change on this file since d619341 was d3911e3, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

Reproduce sasview smear 2D wrapper so 2D fits work. Fixes #811, #825.

  • Property mode set to 100644
File size: 20.1 KB
Line 
1
2import  copy
3#import logging
4import sys
5import math
6import numpy
7
8from sas.sascalc.dataloader.data_info import Data1D
9from sas.sascalc.dataloader.data_info import Data2D
10_SMALLVALUE = 1.0e-10
11
12class FitHandler(object):
13    """
14    Abstract interface for fit thread handler.
15
16    The methods in this class are called by the optimizer as the fit
17    progresses.
18
19    Note that it is up to the optimizer to call the fit handler correctly,
20    reporting all status changes and maintaining the 'done' flag.
21    """
22    done = False
23    """True when the fit job is complete"""
24    result = None
25    """The current best result of the fit"""
26
27    def improvement(self):
28        """
29        Called when a result is observed which is better than previous
30        results from the fit.
31
32        result is a FitResult object, with parameters, #calls and fitness.
33        """
34    def error(self, msg):
35        """
36        Model had an error; print traceback
37        """
38    def progress(self, current, expected):
39        """
40        Called each cycle of the fit, reporting the current and the
41        expected amount of work.   The meaning of these values is
42        optimizer dependent, but they can be converted into a percent
43        complete using (100*current)//expected.
44
45        Progress is updated each iteration of the fit, whatever that
46        means for the particular optimization algorithm.  It is called
47        after any calls to improvement for the iteration so that the
48        update handler can control I/O bandwidth by suppressing
49        intermediate improvements until the fit is complete.
50        """
51    def finalize(self):
52        """
53        Fit is complete; best results are reported
54        """
55    def abort(self):
56        """
57        Fit was aborted.
58        """
59
60    # TODO: not sure how these are used, but they are needed for running the fit
61    def update_fit(self, last=False): pass
62    def set_result(self, result=None): self.result = result
63
64class Model:
65    """
66    Fit wrapper for SAS models.
67    """
68    def __init__(self, sas_model, sas_data=None, **kw):
69        """
70        :param sas_model: the sas model to wrap for fitting
71
72        """
73        self.model = sas_model
74        self.name = sas_model.name
75        self.data = sas_data
76
77    def get_params(self, fitparams):
78        """
79        return a list of value of paramter to fit
80
81        :param fitparams: list of paramaters name to fit
82
83        """
84        return [self.model.getParam(k) for k in fitparams]
85
86    def set_params(self, paramlist, params):
87        """
88        Set value for parameters to fit
89
90        :param params: list of value for parameters to fit
91
92        """
93        for k,v in zip(paramlist, params):
94            self.model.setParam(k,v)
95
96    def set(self, **kw):
97        self.set_params(*zip(*kw.items()))
98
99    def eval(self, x):
100        """
101            Override eval method of model.
102
103            :param x: the x value used to compute a function
104        """
105        try:
106            return self.model.evalDistribution(x)
107        except:
108            raise
109
110    def eval_derivs(self, x, pars=[]):
111        """
112        Evaluate the model and derivatives wrt pars at x.
113
114        pars is a list of the names of the parameters for which derivatives
115        are desired.
116
117        This method needs to be specialized in the model to evaluate the
118        model function.  Alternatively, the model can implement is own
119        version of residuals which calculates the residuals directly
120        instead of calling eval.
121        """
122        raise NotImplementedError('no derivatives available')
123
124    def __call__(self, x):
125        return self.eval(x)
126
127class FitData1D(Data1D):
128    """
129        Wrapper class  for SAS data
130        FitData1D inherits from DataLoader.data_info.Data1D. Implements
131        a way to get residuals from data.
132    """
133    def __init__(self, x, y, dx=None, dy=None, smearer=None, data=None):
134        """
135            :param smearer: is an object of class QSmearer or SlitSmearer
136               that will smear the theory data (slit smearing or resolution
137               smearing) when set.
138           
139            The proper way to set the smearing object would be to
140            do the following: ::
141           
142                from sas.sascalc.data_util.qsmearing import smear_selection
143                smearer = smear_selection(some_data)
144                fitdata1d = FitData1D( x= [1,3,..,],
145                                        y= [3,4,..,8],
146                                        dx=None,
147                                        dy=[1,2...], smearer= smearer)
148           
149            :Note: that some_data _HAS_ to be of
150                class DataLoader.data_info.Data1D
151                Setting it back to None will turn smearing off.
152               
153        """
154        Data1D.__init__(self, x=x, y=y, dx=dx, dy=dy)
155        self.num_points = len(x)
156        self.sas_data = data
157        self.smearer = smearer
158        self._first_unsmeared_bin = None
159        self._last_unsmeared_bin = None
160        # Check error bar; if no error bar found, set it constant(=1)
161        # TODO: Should provide an option for users to set it like percent,
162        # constant, or dy data
163        if dy is None or dy == [] or dy.all() == 0:
164            self.dy = numpy.ones(len(y))
165        else:
166            self.dy = numpy.asarray(dy).copy()
167
168        ## Min Q-value
169        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
170        if min(self.x) == 0.0 and self.x[0] == 0 and\
171                     not numpy.isfinite(self.y[0]):
172            self.qmin = min(self.x[self.x != 0])
173        else:
174            self.qmin = min(self.x)
175        ## Max Q-value
176        self.qmax = max(self.x)
177       
178        # Range used for input to smearing
179        self._qmin_unsmeared = self.qmin
180        self._qmax_unsmeared = self.qmax
181        # Identify the bin range for the unsmeared and smeared spaces
182        self.idx = (self.x >= self.qmin) & (self.x <= self.qmax)
183        self.idx_unsmeared = (self.x >= self._qmin_unsmeared) \
184                            & (self.x <= self._qmax_unsmeared)
185 
186    def set_fit_range(self, qmin=None, qmax=None):
187        """ to set the fit range"""
188        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
189        # ToDo: Find better way to do it.
190        if qmin == 0.0 and not numpy.isfinite(self.y[qmin]):
191            self.qmin = min(self.x[self.x != 0])
192        elif qmin != None:
193            self.qmin = qmin
194        if qmax != None:
195            self.qmax = qmax
196        # Determine the range needed in unsmeared-Q to cover
197        # the smeared Q range
198        self._qmin_unsmeared = self.qmin
199        self._qmax_unsmeared = self.qmax
200       
201        self._first_unsmeared_bin = 0
202        self._last_unsmeared_bin = len(self.x) - 1
203       
204        if self.smearer != None:
205            self._first_unsmeared_bin, self._last_unsmeared_bin = \
206                    self.smearer.get_bin_range(self.qmin, self.qmax)
207            self._qmin_unsmeared = self.x[self._first_unsmeared_bin]
208            self._qmax_unsmeared = self.x[self._last_unsmeared_bin]
209           
210        # Identify the bin range for the unsmeared and smeared spaces
211        self.idx = (self.x >= self.qmin) & (self.x <= self.qmax)
212        ## zero error can not participate for fitting
213        self.idx = self.idx & (self.dy != 0)
214        self.idx_unsmeared = (self.x >= self._qmin_unsmeared) \
215                            & (self.x <= self._qmax_unsmeared)
216
217    def get_fit_range(self):
218        """
219            Return the range of data.x to fit
220        """
221        return self.qmin, self.qmax
222
223    def size(self):
224        """
225        Number of measurement points in data set after masking, etc.
226        """
227        return len(self.x)
228
229    def residuals(self, fn):
230        """
231            Compute residuals.
232           
233            If self.smearer has been set, use if to smear
234            the data before computing chi squared.
235           
236            :param fn: function that return model value
237           
238            :return: residuals
239        """
240        # Compute theory data f(x)
241        fx = numpy.zeros(len(self.x))
242        fx[self.idx_unsmeared] = fn(self.x[self.idx_unsmeared])
243       
244        ## Smear theory data
245        if self.smearer is not None:
246            fx = self.smearer(fx, self._first_unsmeared_bin,
247                              self._last_unsmeared_bin)
248        ## Sanity check
249        if numpy.size(self.dy) != numpy.size(fx):
250            msg = "FitData1D: invalid error array "
251            msg += "%d <> %d" % (numpy.shape(self.dy), numpy.size(fx))
252            raise RuntimeError, msg
253        return (self.y[self.idx] - fx[self.idx]) / self.dy[self.idx], fx[self.idx]
254           
255    def residuals_deriv(self, model, pars=[]):
256        """
257            :return: residuals derivatives .
258           
259            :note: in this case just return empty array
260        """
261        return []
262
263
264class FitData2D(Data2D):
265    """
266        Wrapper class  for SAS data
267    """
268    def __init__(self, sas_data2d, data=None, err_data=None):
269        Data2D.__init__(self, data=data, err_data=err_data)
270        # Data can be initialized with a sas plottable or with vectors.
271        self.res_err_image = []
272        self.num_points = 0 # will be set by set_data
273        self.idx = []
274        self.qmin = None
275        self.qmax = None
276        self.smearer = None
277        self.radius = 0
278        self.res_err_data = []
279        self.sas_data = sas_data2d
280        self.set_data(sas_data2d)
281
282    def set_data(self, sas_data2d, qmin=None, qmax=None):
283        """
284            Determine the correct qx_data and qy_data within range to fit
285        """
286        self.data = sas_data2d.data
287        self.err_data = sas_data2d.err_data
288        self.qx_data = sas_data2d.qx_data
289        self.qy_data = sas_data2d.qy_data
290        self.mask = sas_data2d.mask
291
292        x_max = max(math.fabs(sas_data2d.xmin), math.fabs(sas_data2d.xmax))
293        y_max = max(math.fabs(sas_data2d.ymin), math.fabs(sas_data2d.ymax))
294       
295        ## fitting range
296        if qmin == None:
297            self.qmin = 1e-16
298        if qmax == None:
299            self.qmax = math.sqrt(x_max * x_max + y_max * y_max)
300        ## new error image for fitting purpose
301        if self.err_data == None or self.err_data == []:
302            self.res_err_data = numpy.ones(len(self.data))
303        else:
304            self.res_err_data = copy.deepcopy(self.err_data)
305        #self.res_err_data[self.res_err_data==0]=1
306       
307        self.radius = numpy.sqrt(self.qx_data**2 + self.qy_data**2)
308       
309        # Note: mask = True: for MASK while mask = False for NOT to mask
310        self.idx = ((self.qmin <= self.radius) &\
311                            (self.radius <= self.qmax))
312        self.idx = (self.idx) & (self.mask)
313        self.idx = (self.idx) & (numpy.isfinite(self.data))
314        self.num_points = numpy.sum(self.idx)
315
316    def set_smearer(self, smearer):
317        """
318            Set smearer
319        """
320        if smearer == None:
321            return
322        self.smearer = smearer
323        self.smearer.set_index(self.idx)
324        self.smearer.get_data()
325
326    def set_fit_range(self, qmin=None, qmax=None):
327        """
328            To set the fit range
329        """
330        if qmin == 0.0:
331            self.qmin = 1e-16
332        elif qmin != None:
333            self.qmin = qmin
334        if qmax != None:
335            self.qmax = qmax
336        self.radius = numpy.sqrt(self.qx_data**2 + self.qy_data**2)
337        self.idx = ((self.qmin <= self.radius) &\
338                            (self.radius <= self.qmax))
339        self.idx = (self.idx) & (self.mask)
340        self.idx = (self.idx) & (numpy.isfinite(self.data))
341        self.idx = (self.idx) & (self.res_err_data != 0)
342
343    def get_fit_range(self):
344        """
345        return the range of data.x to fit
346        """
347        return self.qmin, self.qmax
348
349    def size(self):
350        """
351        Number of measurement points in data set after masking, etc.
352        """
353        return numpy.sum(self.idx)
354
355    def residuals(self, fn):
356        """
357        return the residuals
358        """
359        if self.smearer != None:
360            fn.set_index(self.idx)
361            gn = fn.get_value()
362        else:
363            gn = fn([self.qx_data[self.idx],
364                     self.qy_data[self.idx]])
365        # use only the data point within ROI range
366        res = (self.data[self.idx] - gn) / self.res_err_data[self.idx]
367
368        return res, gn
369       
370    def residuals_deriv(self, model, pars=[]):
371        """
372        :return: residuals derivatives .
373       
374        :note: in this case just return empty array
375       
376        """
377        return []
378   
379   
380class FitAbort(Exception):
381    """
382    Exception raise to stop the fit
383    """
384    #pass
385    #print"Creating fit abort Exception"
386
387
388
389class FitEngine:
390    def __init__(self):
391        """
392        Base class for the fit engine
393        """
394        #Dictionnary of fitArrange element (fit problems)
395        self.fit_arrange_dict = {}
396        self.fitter_id = None
397       
398    def set_model(self, model, id, pars=[], constraints=[], data=None):
399        """
400        set a model on a given  in the fit engine.
401       
402        :param model: sas.models type
403        :param id: is the key of the fitArrange dictionary where model is saved as a value
404        :param pars: the list of parameters to fit
405        :param constraints: list of
406            tuple (name of parameter, value of parameters)
407            the value of parameter must be a string to constraint 2 different
408            parameters.
409            Example: 
410            we want to fit 2 model M1 and M2 both have parameters A and B.
411            constraints can be ``constraints = [(M1.A, M2.B+2), (M1.B= M2.A *5),...,]``
412           
413             
414        :note: pars must contains only name of existing model's parameters
415       
416        """
417        if not pars:
418            raise ValueError("no fitting parameters")
419
420        if model is None:
421            raise ValueError("no model to fit")
422
423        if not issubclass(model.__class__, Model):
424            model = Model(model, data)
425
426        sasmodel = model.model
427        available_parameters = sasmodel.getParamList()
428        for p in pars:
429            if p not in available_parameters:
430                raise ValueError("parameter %s not available in model %s; use one of [%s] instead"
431                                 %(p, sasmodel.name, ", ".join(available_parameters)))
432
433        if id not in self.fit_arrange_dict:
434            self.fit_arrange_dict[id] = FitArrange()
435
436        self.fit_arrange_dict[id].set_model(model)
437        self.fit_arrange_dict[id].pars = pars
438        self.fit_arrange_dict[id].vals = [sasmodel.getParam(name) for name in pars]
439        self.fit_arrange_dict[id].constraints = constraints
440
441    def set_data(self, data, id, smearer=None, qmin=None, qmax=None):
442        """
443        Receives plottable, creates a list of data to fit,set data
444        in a FitArrange object and adds that object in a dictionary
445        with key id.
446       
447        :param data: data added
448        :param id: unique key corresponding to a fitArrange object with data
449        """
450        if data.__class__.__name__ == 'Data2D':
451            fitdata = FitData2D(sas_data2d=data, data=data.data,
452                                 err_data=data.err_data)
453        else:
454            fitdata = FitData1D(x=data.x, y=data.y,
455                                 dx=data.dx, dy=data.dy, smearer=smearer)
456        fitdata.sas_data = data
457       
458        fitdata.set_fit_range(qmin=qmin, qmax=qmax)
459        #A fitArrange is already created but contains model only at id
460        if id in self.fit_arrange_dict:
461            self.fit_arrange_dict[id].add_data(fitdata)
462        else:
463        #no fitArrange object has been create with this id
464            fitproblem = FitArrange()
465            fitproblem.add_data(fitdata)
466            self.fit_arrange_dict[id] = fitproblem
467   
468    def get_model(self, id):
469        """
470        :param id: id is key in the dictionary containing the model to return
471       
472        :return:  a model at this id or None if no FitArrange element was
473            created with this id
474        """
475        if id in self.fit_arrange_dict:
476            return self.fit_arrange_dict[id].get_model()
477        else:
478            return None
479   
480    def remove_fit_problem(self, id):
481        """remove   fitarrange in id"""
482        if id in self.fit_arrange_dict:
483            del self.fit_arrange_dict[id]
484           
485    def select_problem_for_fit(self, id, value):
486        """
487        select a couple of model and data at the id position in dictionary
488        and set in self.selected value to value
489       
490        :param value: the value to allow fitting.
491                can only have the value one or zero
492        """
493        if id in self.fit_arrange_dict:
494            self.fit_arrange_dict[id].set_to_fit(value)
495             
496    def get_problem_to_fit(self, id):
497        """
498        return the self.selected value of the fit problem of id
499       
500        :param id: the id of the problem
501        """
502        if id in self.fit_arrange_dict:
503            self.fit_arrange_dict[id].get_to_fit()
504   
505   
506class FitArrange:
507    def __init__(self):
508        """
509        Class FitArrange contains a set of data for a given model
510        to perform the Fit.FitArrange must contain exactly one model
511        and at least one data for the fit to be performed.
512       
513        model: the model selected by the user
514        Ldata: a list of data what the user wants to fit
515           
516        """
517        self.model = None
518        self.data_list = []
519        self.pars = []
520        self.vals = []
521        self.selected = 0
522
523    def set_model(self, model):
524        """
525        set_model save a copy of the model
526       
527        :param model: the model being set
528        """
529        self.model = model
530       
531    def add_data(self, data):
532        """
533        add_data fill a self.data_list with data to fit
534       
535        :param data: Data to add in the list
536        """
537        if not data in self.data_list:
538            self.data_list.append(data)
539           
540    def get_model(self):
541        """
542        :return: saved model
543        """
544        return self.model
545     
546    def get_data(self):
547        """
548        :return: list of data data_list
549        """
550        return self.data_list[0]
551     
552    def remove_data(self, data):
553        """
554        Remove one element from the list
555       
556        :param data: Data to remove from data_list
557        """
558        if data in self.data_list:
559            self.data_list.remove(data)
560           
561    def set_to_fit(self, value=0):
562        """
563        set self.selected to 0 or 1  for other values raise an exception
564       
565        :param value: integer between 0 or 1
566        """
567        self.selected = value
568       
569    def get_to_fit(self):
570        """
571        return self.selected value
572        """
573        return self.selected
574
575class FResult(object):
576    """
577    Storing fit result
578    """
579    def __init__(self, model=None, param_list=None, data=None):
580        self.calls = None
581        self.fitness = None
582        self.chisqr = None
583        self.pvec = []
584        self.cov = []
585        self.info = None
586        self.mesg = None
587        self.success = None
588        self.stderr = None
589        self.residuals = []
590        self.index = []
591        self.model = model
592        self.data = data
593        self.theory = []
594        self.param_list = param_list
595        self.iterations = 0
596        self.inputs = []
597        self.fitter_id = None
598        if self.model is not None and self.data is not None:
599            self.inputs = [(self.model, self.data)]
600     
601    def set_model(self, model):
602        """
603        """
604        self.model = model
605       
606    def set_fitness(self, fitness):
607        """
608        """
609        self.fitness = fitness
610       
611    def __str__(self):
612        """
613        """
614        if self.pvec == None and self.model is None and self.param_list is None:
615            return "No results"
616
617        sasmodel = self.model.model
618        pars = enumerate(sasmodel.getParamList())
619        msg1 = "[Iteration #: %s ]" % self.iterations
620        msg3 = "=== goodness of fit: %s ===" % (str(self.fitness))
621        msg2 = ["P%-3d  %s......|.....%s" % (i, v, sasmodel.getParam(v))
622                for i,v in pars if v in self.param_list]
623        msg = [msg1, msg3] + msg2
624        return "\n".join(msg)
625   
626    def print_summary(self):
627        """
628        """
629        print str(self)
Note: See TracBrowser for help on using the repository browser.