source: sasview/src/sas/sascalc/fit/AbstractFitEngine.py @ f4e2f22

magnetic_scatt, release-4.2.2, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, unittest-saveload
Last change on this file since f4e2f22 was 20fa5fe, checked in by Stuart Prescott <stuart@…>, 7 years ago

Fix lots more typos in comments and docs

  • Property mode set to 100644
File size: 19.7 KB
RevLine 
[a1b8fee]1from __future__ import print_function
[51f14603]2
3import  copy
4#import logging
5import sys
6import math
[9a5097c]7import numpy as np
[6fe5100]8
[b699768]9from sas.sascalc.dataloader.data_info import Data1D
10from sas.sascalc.dataloader.data_info import Data2D
# Small positive constant (1e-10) exported at module level.
# NOTE(review): not referenced anywhere in this module; presumably imported
# by other fitting modules -- confirm before changing or removing it.
_SMALLVALUE = 1e-10

class FitHandler(object):
    """
    Abstract interface for a fit-thread handler.

    The optimizer calls these hooks while a fit progresses.  The handler
    does not police the call sequence: the optimizer is responsible for
    invoking the hooks correctly, reporting every status change and
    maintaining the ``done`` flag.
    """
    # True when the fit job is complete.
    done = False
    # The current best result of the fit (None until one is stored).
    result = None

    def improvement(self):
        """
        Hook: a result better than all previous results has been observed.

        The improved result is a FitResult object holding the parameters,
        the number of calls and the fitness.  The base implementation does
        nothing.
        """

    def error(self, msg):
        """
        Hook: the model had an error, described by *msg*.

        Handlers are expected to report it (e.g. print the traceback); the
        base implementation does nothing.
        """

    def progress(self, current, expected):
        """
        Hook: called once per cycle of the fit with the work done so far.

        *current* and *expected* are optimizer-dependent measures of work;
        a percent-complete figure is ``(100*current)//expected``.

        What counts as one iteration depends on the optimization algorithm.
        This hook runs after any ``improvement`` calls for the iteration, so
        a handler can throttle I/O by suppressing intermediate improvements
        until the fit is complete.
        """

    def finalize(self):
        """
        Hook: the fit is complete and the best results are being reported.
        """

    def abort(self):
        """
        Hook: the fit was aborted.
        """

    # TODO: not sure how these are used, but they are needed for running the fit
    def update_fit(self, last=False):
        """
        Hook required by the fit runner; the base implementation is a no-op.

        :param last: flag supplied by the caller; ignored here
        """
        return None

    def set_result(self, result=None):
        """
        Store *result* as the current best result of the fit.
        """
        self.result = result

class Model:
    """
    Fit wrapper for SAS models.

    Adapts a SAS model object -- one exposing ``name``, ``getParam``,
    ``setParam`` and ``evalDistribution`` -- to the interface used by the
    fit engines.
    """
    def __init__(self, sas_model, sas_data=None, **kw):
        """
        :param sas_model: the sas model to wrap for fitting
        :param sas_data: optional data associated with the model
        :param kw: accepted for backward compatibility; currently ignored
        """
        self.model = sas_model
        self.name = sas_model.name
        self.data = sas_data

    def get_params(self, fitparams):
        """
        Return a list of the values of the parameters to fit.

        :param fitparams: list of names of the parameters to fit
        :return: list of values, in the same order as *fitparams*
        """
        return [self.model.getParam(k) for k in fitparams]

    def set_params(self, paramlist, params):
        """
        Set values for parameters to fit.

        :param paramlist: list of parameter names
        :param params: list of values, paired positionally with *paramlist*
            (surplus entries of the longer list are ignored, as with zip)
        """
        for name, value in zip(paramlist, params):
            self.model.setParam(name, value)

    def set(self, **kw):
        """
        Set parameters by keyword, e.g. ``model.set(scale=1.0)``.

        Calling with no keywords is a no-op.
        """
        # zip(*{}.items()) produces nothing, which would call set_params()
        # with no arguments and raise TypeError; guard the empty case.
        if not kw:
            return
        self.set_params(*zip(*kw.items()))

    def eval(self, x):
        """
        Override eval method of model.

        :param x: the x value(s) passed to the wrapped model's
            ``evalDistribution``
        :return: whatever ``evalDistribution`` returns
        """
        # Any exception raised by the wrapped model propagates unchanged.
        return self.model.evalDistribution(x)

    def eval_derivs(self, x, pars=None):
        """
        Evaluate the model and derivatives wrt pars at x.

        pars is a list of the names of the parameters for which derivatives
        are desired.

        This method needs to be specialized in the model to evaluate the
        model function.  Alternatively, the model can implement its own
        version of residuals which calculates the residuals directly
        instead of calling eval.

        :raises NotImplementedError: always; no derivatives are available
        """
        raise NotImplementedError('no derivatives available')

    def __call__(self, x):
        """Evaluate the model at *x*; shorthand for :meth:`eval`."""
        return self.eval(x)

class FitData1D(Data1D):
    """
    Wrapper class for SAS data.

    FitData1D inherits from DataLoader.data_info.Data1D and implements a
    way to get residuals from the data.
    """
    def __init__(self, x, y, dx=None, dy=None, smearer=None, data=None,
                 lam=None, dlam=None):
        """
        :param x: Q values of the data points
        :param y: measured values at each x
        :param dx: passed through to Data1D.__init__
        :param dy: uncertainties on y; replaced by ones when missing, empty
            or containing any zero (see the note in the body)
        :param smearer: is an object of class QSmearer or SlitSmearer
            that will smear the theory data (slit smearing or resolution
            smearing) when set.
        :param data: the original data object, stored as ``sas_data``
        :param lam: passed through to Data1D.__init__
        :param dlam: passed through to Data1D.__init__

        The proper way to set the smearing object would be to
        do the following: ::

            from sas.sascalc.fit.qsmearing import smear_selection
            smearer = smear_selection(some_data)
            fitdata1d = FitData1D( x= [1,3,..,],
                                    y= [3,4,..,8],
                                    dx=None,
                                    dy=[1,2...], smearer= smearer)

        :Note: that some_data _HAS_ to be of
            class DataLoader.data_info.Data1D
            Setting it back to None will turn smearing off.
        """
        Data1D.__init__(self, x=x, y=y, dx=dx, dy=dy, lam=lam, dlam=dlam)
        self.num_points = len(x)
        self.sas_data = data
        self.smearer = smearer
        self._first_unsmeared_bin = None
        self._last_unsmeared_bin = None
        # Check error bar; if no error bar found, set it constant (=1).
        # Convert to an array first so lists and ndarrays behave alike:
        # lists have no .all(), and comparing a non-empty ndarray with []
        # can raise on recent numpy versions.
        # NOTE(review): a single zero in dy replaces *all* error bars by 1,
        # so the ``dy != 0`` filter in set_fit_range never removes points;
        # confirm this is the intended policy.
        # TODO: Should provide an option for users to set it like percent,
        # constant, or dy data
        dy_values = None if dy is None else np.asarray(dy)
        if dy_values is None or dy_values.size == 0 or not dy_values.all():
            self.dy = np.ones(len(y))
        else:
            self.dy = dy_values.copy()

        ## Min Q-value
        # Skip the Q=0 point, especially when y(q=0)=None at x[0].
        if min(self.x) == 0.0 and self.x[0] == 0 and \
                not np.isfinite(self.y[0]):
            self.qmin = min(self.x[self.x != 0])
        else:
            self.qmin = min(self.x)
        ## Max Q-value
        self.qmax = max(self.x)

        # Range used for input to smearing
        self._qmin_unsmeared = self.qmin
        self._qmax_unsmeared = self.qmax
        # Identify the bin range for the unsmeared and smeared spaces
        self.idx = (self.x >= self.qmin) & (self.x <= self.qmax)
        self.idx_unsmeared = (self.x >= self._qmin_unsmeared) \
                            & (self.x <= self._qmax_unsmeared)

    def set_fit_range(self, qmin=None, qmax=None):
        """
        Set the fit range and recompute the fitting masks.

        :param qmin: lower Q bound; ``None`` keeps the current value.  When
            0.0 is given and the first y value is not finite, the smallest
            non-zero Q is used instead.
        :param qmax: upper Q bound; ``None`` keeps the current value.
        """
        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
        # Check the first point, y[0]; indexing y with the float qmin
        # itself would raise IndexError.
        # ToDo: Find better way to do it.
        if qmin == 0.0 and not np.isfinite(self.y[0]):
            self.qmin = min(self.x[self.x != 0])
        elif qmin is not None:
            self.qmin = qmin
        if qmax is not None:
            self.qmax = qmax
        # Determine the range needed in unsmeared-Q to cover
        # the smeared Q range
        self._qmin_unsmeared = self.qmin
        self._qmax_unsmeared = self.qmax

        self._first_unsmeared_bin = 0
        self._last_unsmeared_bin = len(self.x) - 1

        if self.smearer is not None:
            self._first_unsmeared_bin, self._last_unsmeared_bin = \
                    self.smearer.get_bin_range(self.qmin, self.qmax)
            self._qmin_unsmeared = self.x[self._first_unsmeared_bin]
            self._qmax_unsmeared = self.x[self._last_unsmeared_bin]

        # Identify the bin range for the unsmeared and smeared spaces
        self.idx = (self.x >= self.qmin) & (self.x <= self.qmax)
        ## zero error can not participate for fitting
        self.idx = self.idx & (self.dy != 0)
        self.idx_unsmeared = (self.x >= self._qmin_unsmeared) \
                            & (self.x <= self._qmax_unsmeared)

    def get_fit_range(self):
        """
        Return the range of data.x to fit, as a (qmin, qmax) tuple.
        """
        return self.qmin, self.qmax

    def size(self):
        """
        Number of measurement points in the data set.

        NOTE(review): unlike FitData2D.size this returns len(self.x) and
        does not apply the fit-range / zero-error mask in ``self.idx``;
        confirm whether callers expect the masked count.
        """
        return len(self.x)

    def residuals(self, fn):
        """
        Compute residuals.

        If self.smearer has been set, use it to smear the theory data
        before computing the residuals.

        :param fn: function that returns the model value at given x values
        :return: tuple (residuals, theory) restricted to the points selected
            by ``self.idx``
        :raises RuntimeError: if the error array and the theory array
            differ in size
        """
        # Compute theory data f(x)
        fx = np.zeros(len(self.x))
        fx[self.idx_unsmeared] = fn(self.x[self.idx_unsmeared])

        ## Smear theory data
        if self.smearer is not None:
            fx = self.smearer(fx, self._first_unsmeared_bin,
                              self._last_unsmeared_bin)
        ## Sanity check
        if np.size(self.dy) != np.size(fx):
            msg = "FitData1D: invalid error array "
            # Both operands must be integers for %d (np.shape is a tuple).
            msg += "%d <> %d" % (np.size(self.dy), np.size(fx))
            raise RuntimeError(msg)
        return (self.y[self.idx] - fx[self.idx]) / self.dy[self.idx], fx[self.idx]

    def residuals_deriv(self, model, pars=None):
        """
        :return: residuals derivatives .

        :note: in this case just return empty array
        """
        return []

class FitData2D(Data2D):
    """
    Wrapper class for SAS 2D data.

    FitData2D inherits from DataLoader.data_info.Data2D and implements a
    way to get residuals from the data.
    """
    def __init__(self, sas_data2d, data=None, err_data=None):
        """
        :param sas_data2d: the 2D data object to fit; set_data copies its
            arrays (by reference) onto this object
        :param data: passed to Data2D.__init__
        :param err_data: passed to Data2D.__init__
        """
        Data2D.__init__(self, data=data, err_data=err_data)
        # Data can be initialized with a sas plottable or with vectors.
        self.res_err_image = []
        self.num_points = 0  # will be set by set_data
        self.idx = []
        self.qmin = None
        self.qmax = None
        self.smearer = None
        self.radius = 0
        self.res_err_data = []
        self.sas_data = sas_data2d
        self.set_data(sas_data2d)

    def set_data(self, sas_data2d, qmin=None, qmax=None):
        """
        Determine the correct qx_data and qy_data within range to fit.

        :param sas_data2d: the 2D data object providing data, err_data,
            qx_data, qy_data, mask and the xmin/xmax/ymin/ymax limits
        :param qmin: lower radial Q bound; ``None`` or 0.0 means 1e-16,
            which excludes points at radius 0
        :param qmax: upper radial Q bound; ``None`` means
            sqrt(x_max**2 + y_max**2) from the largest |x| and |y| limits
        """
        self.data = sas_data2d.data
        self.err_data = sas_data2d.err_data
        self.qx_data = sas_data2d.qx_data
        self.qy_data = sas_data2d.qy_data
        self.mask = sas_data2d.mask

        x_max = max(math.fabs(sas_data2d.xmin), math.fabs(sas_data2d.xmax))
        y_max = max(math.fabs(sas_data2d.ymin), math.fabs(sas_data2d.ymax))

        ## fitting range
        # Explicit qmin/qmax arguments are honoured; 0.0 is mapped to
        # 1e-16 exactly as in set_fit_range.
        if qmin is None or qmin == 0.0:
            self.qmin = 1e-16
        else:
            self.qmin = qmin
        if qmax is None:
            self.qmax = math.sqrt(x_max * x_max + y_max * y_max)
        else:
            self.qmax = qmax
        ## new error image for fitting purpose
        # np.size treats lists and arrays alike; comparing a non-empty
        # ndarray with [] can raise on recent numpy versions.
        if self.err_data is None or np.size(self.err_data) == 0:
            self.res_err_data = np.ones(len(self.data))
        else:
            self.res_err_data = copy.deepcopy(self.err_data)

        self.radius = np.sqrt(self.qx_data**2 + self.qy_data**2)

        # Points are kept where mask is True, the radius is inside
        # [qmin, qmax] and the data value is finite.
        self.idx = ((self.qmin <= self.radius) &
                    (self.radius <= self.qmax))
        self.idx = (self.idx) & (self.mask)
        self.idx = (self.idx) & (np.isfinite(self.data))
        self.num_points = np.sum(self.idx)

    def set_smearer(self, smearer):
        """
        Set the smearer and hand it the current fitting index.

        Passing ``None`` leaves the current smearer unchanged.
        """
        if smearer is None:
            return
        self.smearer = smearer
        self.smearer.set_index(self.idx)
        self.smearer.get_data()

    def set_fit_range(self, qmin=None, qmax=None):
        """
        Set the fit range and recompute the fitting index.

        :param qmin: lower radial Q bound; 0.0 is mapped to 1e-16 and
            ``None`` keeps the current value
        :param qmax: upper radial Q bound; ``None`` keeps the current value
        """
        if qmin == 0.0:
            self.qmin = 1e-16
        elif qmin is not None:
            self.qmin = qmin
        if qmax is not None:
            self.qmax = qmax
        self.radius = np.sqrt(self.qx_data**2 + self.qy_data**2)
        self.idx = ((self.qmin <= self.radius) &
                    (self.radius <= self.qmax))
        self.idx = (self.idx) & (self.mask)
        self.idx = (self.idx) & (np.isfinite(self.data))
        # zero error can not participate in the fit
        self.idx = (self.idx) & (self.res_err_data != 0)

    def get_fit_range(self):
        """
        Return the radial Q range to fit, as a (qmin, qmax) tuple.
        """
        return self.qmin, self.qmax

    def size(self):
        """
        Number of measurement points in data set after masking, etc.
        """
        return np.sum(self.idx)

    def residuals(self, fn):
        """
        Return the residuals.

        :param fn: when a smearer is set, an object exposing
            ``set_index``/``get_value``; otherwise a callable evaluated on
            ``[qx, qy]`` of the points selected by ``self.idx``
        :return: tuple (residuals, theory) for the selected points
        """
        if self.smearer is not None:
            fn.set_index(self.idx)
            gn = fn.get_value()
        else:
            gn = fn([self.qx_data[self.idx],
                     self.qy_data[self.idx]])
        # use only the data point within ROI range
        res = (self.data[self.idx] - gn) / self.res_err_data[self.idx]

        return res, gn

    def residuals_deriv(self, model, pars=None):
        """
        :return: residuals derivatives .

        :note: in this case just return empty array
        """
        return []


class FitAbort(Exception):
    """
    Raised to stop a running fit.
    """



class FitEngine:
    """
    Base class for the fit engine.

    Keeps a dictionary of FitArrange objects (one model plus its data)
    keyed by a caller-chosen id.
    """
    def __init__(self):
        """
        Base class for the fit engine
        """
        # Dictionary of FitArrange elements (fit problems), keyed by id
        self.fit_arrange_dict = {}
        self.fitter_id = None

    def set_model(self, model, id, pars=None, constraints=None, data=None):
        """
        Set a model on a given fit problem in the fit engine.

        :param model: a Model wrapper, or a sas model which is then wrapped
            as ``Model(model, data)``
        :param id: key of the fit_arrange_dict entry in which the model is
            saved; a new FitArrange is created if the key is absent
        :param pars: the list of parameter names to fit (must be non-empty)
        :param constraints: list of tuples (name of parameter, value of
            parameter); the value must be a string to constrain 2 different
            parameters.  Example: to fit 2 models M1 and M2 that both have
            parameters A and B, constraints can be
            ``constraints = [(M1.A, M2.B+2), (M1.B, M2.A*5), ...]``.
            Defaults to a new empty list for each call.
        :param data: data passed to Model() when *model* needs wrapping

        :raises ValueError: if *pars* is empty, *model* is None, or a name
            in *pars* is not a parameter of the model

        :note: pars must contain only names of existing model parameters
        """
        if not pars:
            raise ValueError("no fitting parameters")

        if model is None:
            raise ValueError("no model to fit")

        # A fresh list per call: a shared default list would be aliased by
        # every FitArrange that received it.
        if constraints is None:
            constraints = []

        if not issubclass(model.__class__, Model):
            model = Model(model, data)

        sasmodel = model.model
        available_parameters = sasmodel.getParamList()
        for p in pars:
            if p not in available_parameters:
                raise ValueError("parameter %s not available in model %s; use one of [%s] instead"
                                 %(p, sasmodel.name, ", ".join(available_parameters)))

        if id not in self.fit_arrange_dict:
            self.fit_arrange_dict[id] = FitArrange()

        fit_arrange = self.fit_arrange_dict[id]
        fit_arrange.set_model(model)
        fit_arrange.pars = pars
        fit_arrange.vals = [sasmodel.getParam(name) for name in pars]
        fit_arrange.constraints = constraints

    def set_data(self, data, id, smearer=None, qmin=None, qmax=None):
        """
        Receive a plottable, wrap it as fit data and store it in the
        FitArrange object with key *id* (creating that object if needed).

        :param data: data added; a ``Data2D`` (by class name) is wrapped in
            FitData2D, anything else in FitData1D
        :param id: unique key corresponding to a fitArrange object with data
        :param smearer: smearer for 1D data (ignored for 2D data)
        :param qmin: lower bound passed to the wrapper's set_fit_range
        :param qmax: upper bound passed to the wrapper's set_fit_range
        """
        if data.__class__.__name__ == 'Data2D':
            fitdata = FitData2D(sas_data2d=data, data=data.data,
                                err_data=data.err_data)
        else:
            fitdata = FitData1D(x=data.x, y=data.y,
                                dx=data.dx, dy=data.dy, smearer=smearer)
        fitdata.sas_data = data

        fitdata.set_fit_range(qmin=qmin, qmax=qmax)
        if id in self.fit_arrange_dict:
            # A FitArrange already exists for this id (e.g. model only)
            self.fit_arrange_dict[id].add_data(fitdata)
        else:
            # No FitArrange object has been created with this id yet
            fitproblem = FitArrange()
            fitproblem.add_data(fitdata)
            self.fit_arrange_dict[id] = fitproblem

    def get_model(self, id):
        """
        :param id: id is key in the dictionary containing the model to return

        :return: a model at this id or None if no FitArrange element was
            created with this id
        """
        if id in self.fit_arrange_dict:
            return self.fit_arrange_dict[id].get_model()
        return None

    def remove_fit_problem(self, id):
        """Remove the FitArrange stored at *id*, if any."""
        if id in self.fit_arrange_dict:
            del self.fit_arrange_dict[id]

    def select_problem_for_fit(self, id, value):
        """
        Select the model/data pair at *id* for fitting by setting its
        ``selected`` value to *value*.  Unknown ids are ignored.

        :param value: the value to allow fitting; expected to be 1 or 0
        """
        if id in self.fit_arrange_dict:
            self.fit_arrange_dict[id].set_to_fit(value)

    def get_problem_to_fit(self, id):
        """
        Return the ``selected`` value of the fit problem of *id*.

        :param id: the id of the problem
        :return: the selected value, or None if no problem has this id
        """
        if id in self.fit_arrange_dict:
            return self.fit_arrange_dict[id].get_to_fit()
        return None


class FitArrange:
    """
    A set of data for a given model to perform the fit.

    A FitArrange must contain exactly one model and at least one data set
    for the fit to be performed.
    """
    def __init__(self):
        """
        Attributes:

        model: the model selected by the user
        data_list: a list of data the user wants to fit
        pars: names of the parameters to fit
        vals: parameter values captured when the model was set
        constraints: parameter constraints for this problem
        selected: fit-selection flag (expected 0 or 1)
        """
        self.model = None
        self.data_list = []
        self.pars = []
        self.vals = []
        # Initialized here so the attribute always exists, even before a
        # model (and its constraints) has been set.
        self.constraints = []
        self.selected = 0

    def set_model(self, model):
        """
        Store *model* (by reference, not a copy).

        :param model: the model being set
        """
        self.model = model

    def add_data(self, data):
        """
        Append *data* to self.data_list unless it is already present.

        :param data: Data to add in the list
        """
        if data not in self.data_list:
            self.data_list.append(data)

    def get_model(self):
        """
        :return: saved model
        """
        return self.model

    def get_data(self):
        """
        :return: the first data set in data_list

        NOTE(review): raises IndexError when no data has been added.
        """
        return self.data_list[0]

    def remove_data(self, data):
        """
        Remove one element from the list; absent data is ignored.

        :param data: Data to remove from data_list
        """
        if data in self.data_list:
            self.data_list.remove(data)

    def set_to_fit(self, value=0):
        """
        Set self.selected to *value*.

        :param value: expected to be 0 or 1; the value is stored as given
            without validation
        """
        self.selected = value

    def get_to_fit(self):
        """
        Return the self.selected value.
        """
        return self.selected

class FResult(object):
    """
    Storing fit result.
    """
    def __init__(self, model=None, param_list=None, data=None):
        """
        :param model: the fitted model wrapper; its ``model`` attribute is
            the sas model queried by __str__
        :param param_list: names of the fitted parameters
        :param data: the data that was fitted
        """
        self.calls = None
        self.fitness = None
        self.chisqr = None
        self.pvec = []
        self.cov = []
        self.info = None
        self.mesg = None
        self.success = None
        self.stderr = None
        self.residuals = []
        self.index = []
        self.model = model
        self.data = data
        self.theory = []
        self.param_list = param_list
        self.iterations = 0
        self.inputs = []
        self.fitter_id = None
        # Record the (model, data) pair only when both are known.
        if self.model is not None and self.data is not None:
            self.inputs = [(self.model, self.data)]

    def set_model(self, model):
        """
        Set the fitted model wrapper.
        """
        self.model = model

    def set_fitness(self, fitness):
        """
        Set the goodness-of-fit value reported by __str__.
        """
        self.fitness = fitness

    def __str__(self):
        """
        Return a text summary: iteration count, goodness of fit and the
        current value of each fitted parameter.

        Returns "No results" when there is no model or no parameter list
        to report on.
        """
        # Both the model and param_list are needed to build the summary;
        # pvec is not used here (and defaults to [], never None).
        if self.model is None or self.param_list is None:
            return "No results"

        sasmodel = self.model.model
        lines = ["[Iteration #: %s ]" % self.iterations,
                 "=== goodness of fit: %s ===" % (str(self.fitness))]
        lines.extend("P%-3d  %s......|.....%s" % (i, v, sasmodel.getParam(v))
                     for i, v in enumerate(sasmodel.getParamList())
                     if v in self.param_list)
        return "\n".join(lines)

    def print_summary(self):
        """
        Print the text summary produced by __str__.
        """
        print(str(self))
Note: See TracBrowser for help on using the repository browser.