source: sasview/src/sans/fit/AbstractFitEngine.py @ f2148b2

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since f2148b2 was 5777106, checked in by Mathieu Doucet <doucetm@…>, 11 years ago

Moving things around. Will definitely not build.

  • Property mode set to 100644
File size: 27.6 KB
RevLine 
[aa36f96]1
[89f3b66]2import  copy
[c4d6900]3#import logging
[444c900e]4import sys
[89f3b66]5import numpy
6import math
7import park
[41517a6]8from sans.dataloader.data_info import Data1D
9from sans.dataloader.data_info import Data2D
[4b5bd73]10_SMALLVALUE = 1.0e-10   
[b2f25dc5]11   
[48882d1]12class SansParameter(park.Parameter):
13    """
[aa36f96]14    SANS model parameters for use in the PARK fitting service.
15    The parameter attribute value is redirected to the underlying
16    parameter value in the SANS model.
[48882d1]17    """
[1cff677]18    def __init__(self, name, model, data):
[ca6d914]19        """
[1f9f3c8a]20            :param name: the name of the model parameter
21            :param model: the sans model to wrap as a park model
[ca6d914]22        """
[c4d6900]23        park.Parameter.__init__(self, name)
[89f3b66]24        self._model, self._name = model, name
[1cff677]25        self.data = data
26        self.model = model
[ca6d914]27        #set the value for the parameter of the given name
28        self.set(model.getParam(name))
[48882d1]29         
[ca6d914]30    def _getvalue(self):
31        """
[aa36f96]32        override the _getvalue of park parameter
33       
34        :return value the parameter associates with self.name
35       
[ca6d914]36        """
37        return self._model.getParam(self.name)
[48882d1]38   
[89f3b66]39    def _setvalue(self, value):
[ca6d914]40        """
[aa36f96]41        override the _setvalue pf park parameter
42       
43        :param value: the value to set on a given parameter
44       
[ca6d914]45        """
[48882d1]46        self._model.setParam(self.name, value)
47       
[c4d6900]48    value = property(_getvalue, _setvalue)
[48882d1]49   
50    def _getrange(self):
[ca6d914]51        """
[aa36f96]52        Override _getrange of park parameter
53        return the range of parameter
[ca6d914]54        """
[920a6e5]55        #if not  self.name in self._model.getDispParamList():
[89f3b66]56        lo, hi = self._model.details[self.name][1:3]
[920a6e5]57        if lo is None: lo = -numpy.inf
58        if hi is None: hi = numpy.inf
[e0e22f2c]59        if lo > hi:
[1f9f3c8a]60            raise ValueError, "wrong fit range for parameters"
[05f14dd]61       
[89f3b66]62        return lo, hi
[48882d1]63   
[b2f25dc5]64    def get_name(self):
65        """
66        """
67        return self._getname()
68   
[89f3b66]69    def _setrange(self, r):
[ca6d914]70        """
[aa36f96]71        override _setrange of park parameter
72       
73        :param r: the value of the range to set
74       
[ca6d914]75        """
[12b76cf]76        self._model.details[self.name][1:3] = r
[89f3b66]77    range = property(_getrange, _setrange)
[a9e04aa]78   
[1f9f3c8a]79   
[a9e04aa]80class Model(park.Model):
[48882d1]81    """
[aa36f96]82    PARK wrapper for SANS models.
[48882d1]83    """
[1cff677]84    def __init__(self, sans_model, sans_data=None, **kw):
[ca6d914]85        """
[aa36f96]86        :param sans_model: the sans model to wrap using park interface
87       
[ca6d914]88        """
[a9e04aa]89        park.Model.__init__(self, **kw)
[48882d1]90        self.model = sans_model
[ca6d914]91        self.name = sans_model.name
[1cff677]92        self.data = sans_data
[ca6d914]93        #list of parameters names
[48882d1]94        self.sansp = sans_model.getParamList()
[ca6d914]95        #list of park parameter
[1cff677]96        self.parkp = [SansParameter(p, sans_model, sans_data) for p in self.sansp]
[1f9f3c8a]97        #list of parameter set
[89f3b66]98        self.parameterset = park.ParameterSet(sans_model.name, pars=self.parkp)
99        self.pars = []
[ca6d914]100 
[c4d6900]101    def get_params(self, fitparams):
[ca6d914]102        """
[aa36f96]103        return a list of value of paramter to fit
104       
105        :param fitparams: list of paramaters name to fit
106       
[ca6d914]107        """
[c4d6900]108        list_params = []
[89f3b66]109        self.pars = []
110        self.pars = fitparams
[48882d1]111        for item in fitparams:
112            for element in self.parkp:
[c4d6900]113                if element.name == str(item):
114                    list_params.append(element.value)
115        return list_params
[48882d1]116   
[c4d6900]117    def set_params(self, paramlist, params):
[ca6d914]118        """
[aa36f96]119        Set value for parameters to fit
120       
[1f9f3c8a]121        :param params: list of value for parameters to fit
[aa36f96]122       
[ca6d914]123        """
[e71440c]124        try:
125            for i in range(len(self.parkp)):
126                for j in range(len(paramlist)):
[89f3b66]127                    if self.parkp[i].name == paramlist[j]:
[e71440c]128                        self.parkp[i].value = params[j]
[89f3b66]129                        self.model.setParam(self.parkp[i].name, params[j])
[e71440c]130        except:
131            raise
[ca6d914]132 
[89f3b66]133    def eval(self, x):
[ca6d914]134        """
[1f9f3c8a]135            Override eval method of park model.
[aa36f96]136       
[1f9f3c8a]137            :param x: the x value used to compute a function
[ca6d914]138        """
[d8a2e31]139        try:
[393f0f3]140            return self.model.evalDistribution(x)
[d8a2e31]141        except:
[393f0f3]142            raise
[c4d6900]143       
144    def eval_derivs(self, x, pars=[]):
145        """
146        Evaluate the model and derivatives wrt pars at x.
147
148        pars is a list of the names of the parameters for which derivatives
149        are desired.
150
151        This method needs to be specialized in the model to evaluate the
152        model function.  Alternatively, the model can implement is own
153        version of residuals which calculates the residuals directly
154        instead of calling eval.
155        """
156        return []
157
[b64fa56]158   
[1e3169c]159class FitData1D(Data1D):
[1f9f3c8a]160    """
161        Wrapper class  for SANS data
162        FitData1D inherits from DataLoader.data_info.Data1D. Implements
163        a way to get residuals from data.
[1e3169c]164    """
[634ca14]165    def __init__(self, x, y, dx=None, dy=None, smearer=None, data=None):
[7d0c1a8]166        """
[1f9f3c8a]167            :param smearer: is an object of class QSmearer or SlitSmearer
168               that will smear the theory data (slit smearing or resolution
169               smearing) when set.
170           
171            The proper way to set the smearing object would be to
172            do the following: ::
[109e60ab]173           
[1f9f3c8a]174                from DataLoader.qsmearing import smear_selection
175                smearer = smear_selection(some_data)
176                fitdata1d = FitData1D( x= [1,3,..,],
177                                        y= [3,4,..,8],
178                                        dx=None,
179                                        dy=[1,2...], smearer= smearer)
180           
181            :Note: that some_data _HAS_ to be of
182                class DataLoader.data_info.Data1D
183                Setting it back to None will turn smearing off.
184               
[7d0c1a8]185        """
[89f3b66]186        Data1D.__init__(self, x=x, y=y, dx=dx, dy=dy)
[634ca14]187        self.sans_data = data
[b461b6d7]188        self.smearer = smearer
[c4d6900]189        self._first_unsmeared_bin = None
190        self._last_unsmeared_bin = None
[189be4e]191        # Check error bar; if no error bar found, set it constant(=1)
[c4d6900]192        # TODO: Should provide an option for users to set it like percent,
193        # constant, or dy data
[89f3b66]194        if dy == None or dy == [] or dy.all() == 0:
[1f9f3c8a]195            self.dy = numpy.ones(len(y))
[189be4e]196        else:
[89f3b66]197            self.dy = numpy.asarray(dy).copy()
[189be4e]198
[109e60ab]199        ## Min Q-value
[4bd557d]200        #Skip the Q=0 point, especially when y(q=0)=None at x[0].
[1f9f3c8a]201        if min(self.x) == 0.0 and self.x[0] == 0 and\
[89f3b66]202                     not numpy.isfinite(self.y[0]):
[1f9f3c8a]203            self.qmin = min(self.x[self.x != 0])
204        else:
[89f3b66]205            self.qmin = min(self.x)
[109e60ab]206        ## Max Q-value
[89f3b66]207        self.qmax = max(self.x)
[058b2d7]208       
[72c7d31]209        # Range used for input to smearing
210        self._qmin_unsmeared = self.qmin
211        self._qmax_unsmeared = self.qmax
[fd0d30fd]212        # Identify the bin range for the unsmeared and smeared spaces
[89f3b66]213        self.idx = (self.x >= self.qmin) & (self.x <= self.qmax)
214        self.idx_unsmeared = (self.x >= self._qmin_unsmeared) \
215                            & (self.x <= self._qmax_unsmeared)
[fd0d30fd]216 
[c4d6900]217    def set_fit_range(self, qmin=None, qmax=None):
[7d0c1a8]218        """ to set the fit range"""
[09975cbb]219        # Skip Q=0 point, (especially for y(q=0)=None at x[0]).
[189be4e]220        # ToDo: Find better way to do it.
[89f3b66]221        if qmin == 0.0 and not numpy.isfinite(self.y[qmin]):
222            self.qmin = min(self.x[self.x != 0])
[1f9f3c8a]223        elif qmin != None:
224            self.qmin = qmin
[89f3b66]225        if qmax != None:
[eef2e0ed]226            self.qmax = qmax
[4bb2917]227        # Determine the range needed in unsmeared-Q to cover
228        # the smeared Q range
[72c7d31]229        self._qmin_unsmeared = self.qmin
[1f9f3c8a]230        self._qmax_unsmeared = self.qmax
[72c7d31]231       
[4bb2917]232        self._first_unsmeared_bin = 0
[1f9f3c8a]233        self._last_unsmeared_bin = len(self.x) - 1
[4bb2917]234       
[c4d6900]235        if self.smearer != None:
[89f3b66]236            self._first_unsmeared_bin, self._last_unsmeared_bin = \
237                    self.smearer.get_bin_range(self.qmin, self.qmax)
[1e3169c]238            self._qmin_unsmeared = self.x[self._first_unsmeared_bin]
239            self._qmax_unsmeared = self.x[self._last_unsmeared_bin]
[4bb2917]240           
[fd0d30fd]241        # Identify the bin range for the unsmeared and smeared spaces
[89f3b66]242        self.idx = (self.x >= self.qmin) & (self.x <= self.qmax)
243        ## zero error can not participate for fitting
[1f9f3c8a]244        self.idx = self.idx & (self.dy != 0)
[89f3b66]245        self.idx_unsmeared = (self.x >= self._qmin_unsmeared) \
246                            & (self.x <= self._qmax_unsmeared)
[0766d6d]247
[c4d6900]248    def get_fit_range(self):
[7d0c1a8]249        """
[1f9f3c8a]250            Return the range of data.x to fit
[7d0c1a8]251        """
252        return self.qmin, self.qmax
[72c7d31]253       
[7d0c1a8]254    def residuals(self, fn):
[1f9f3c8a]255        """
256            Compute residuals.
257           
258            If self.smearer has been set, use if to smear
259            the data before computing chi squared.
260           
261            :param fn: function that return model value
262           
263            :return: residuals
[109e60ab]264        """
265        # Compute theory data f(x)
[89f3b66]266        fx = numpy.zeros(len(self.x))
[7e752fe]267        fx[self.idx_unsmeared] = fn(self.x[self.idx_unsmeared])
[fd0d30fd]268       
[d5b488b]269        ## Smear theory data
[109e60ab]270        if self.smearer is not None:
[1f9f3c8a]271            fx = self.smearer(fx, self._first_unsmeared_bin,
[89f3b66]272                              self._last_unsmeared_bin)
[d5b488b]273        ## Sanity check
[89f3b66]274        if numpy.size(self.dy) != numpy.size(fx):
275            msg = "FitData1D: invalid error array "
[1f9f3c8a]276            msg += "%d <> %d" % (numpy.shape(self.dy), numpy.size(fx))
277            raise RuntimeError, msg
[425e49ca]278        return (self.y[self.idx] - fx[self.idx]) / self.dy[self.idx], fx[self.idx]
[444c900e]279           
[7d0c1a8]280    def residuals_deriv(self, model, pars=[]):
[1f9f3c8a]281        """
282            :return: residuals derivatives .
283           
284            :note: in this case just return empty array
[7d0c1a8]285        """
286        return []
287   
[1f9f3c8a]288   
[1e3169c]289class FitData2D(Data2D):
[1f9f3c8a]290    """
291        Wrapper class  for SANS data
292    """
[89f3b66]293    def __init__(self, sans_data2d, data=None, err_data=None):
[c4d6900]294        Data2D.__init__(self, data=data, err_data=err_data)
[7d0c1a8]295        """
[1f9f3c8a]296            Data can be initital with a data (sans plottable)
297            or with vectors.
[7d0c1a8]298        """
[89f3b66]299        self.res_err_image = []
[444c900e]300        self.idx = []
[89f3b66]301        self.qmin = None
302        self.qmax = None
[f72333f]303        self.smearer = None
[c4d6900]304        self.radius = 0
305        self.res_err_data = []
[634ca14]306        self.sans_data = sans_data2d
[89f3b66]307        self.set_data(sans_data2d)
[f72333f]308
[89f3b66]309    def set_data(self, sans_data2d, qmin=None, qmax=None):
[1e3169c]310        """
[1f9f3c8a]311            Determine the correct qx_data and qy_data within range to fit
[1e3169c]312        """
[89f3b66]313        self.data = sans_data2d.data
[83195f7]314        self.err_data = sans_data2d.err_data
315        self.qx_data = sans_data2d.qx_data
316        self.qy_data = sans_data2d.qy_data
[89f3b66]317        self.mask = sans_data2d.mask
[83195f7]318
319        x_max = max(math.fabs(sans_data2d.xmin), math.fabs(sans_data2d.xmax))
320        y_max = max(math.fabs(sans_data2d.ymin), math.fabs(sans_data2d.ymax))
[20d30e9]321       
322        ## fitting range
[027e8f2]323        if qmin == None:
324            self.qmin = 1e-16
325        if qmax == None:
[89f3b66]326            self.qmax = math.sqrt(x_max * x_max + y_max * y_max)
[70bf68c]327        ## new error image for fitting purpose
[89f3b66]328        if self.err_data == None or self.err_data == []:
329            self.res_err_data = numpy.ones(len(self.data))
[70bf68c]330        else:
[da58fcc]331            self.res_err_data = copy.deepcopy(self.err_data)
[9e8c150]332        #self.res_err_data[self.res_err_data==0]=1
[d8a2e31]333       
[89f3b66]334        self.radius = numpy.sqrt(self.qx_data**2 + self.qy_data**2)
[83195f7]335       
336        # Note: mask = True: for MASK while mask = False for NOT to mask
[1f9f3c8a]337        self.idx = ((self.qmin <= self.radius) &\
[89f3b66]338                            (self.radius <= self.qmax))
[444c900e]339        self.idx = (self.idx) & (self.mask)
340        self.idx = (self.idx) & (numpy.isfinite(self.data))
[0766d6d]341
[1f9f3c8a]342    def set_smearer(self, smearer):
[f72333f]343        """
[1f9f3c8a]344            Set smearer
[f72333f]345        """
346        if smearer == None:
347            return
348        self.smearer = smearer
[444c900e]349        self.smearer.set_index(self.idx)
[f72333f]350        self.smearer.get_data()
351
[c4d6900]352    def set_fit_range(self, qmin=None, qmax=None):
[1f9f3c8a]353        """
354            To set the fit range
355        """
[89f3b66]356        if qmin == 0.0:
[773806e]357            self.qmin = 1e-16
[1f9f3c8a]358        elif qmin != None:
359            self.qmin = qmin
[89f3b66]360        if qmax != None:
[1f9f3c8a]361            self.qmax = qmax
[89f3b66]362        self.radius = numpy.sqrt(self.qx_data**2 + self.qy_data**2)
[1f9f3c8a]363        self.idx = ((self.qmin <= self.radius) &\
[89f3b66]364                            (self.radius <= self.qmax))
[1f9f3c8a]365        self.idx = (self.idx) & (self.mask)
[444c900e]366        self.idx = (self.idx) & (numpy.isfinite(self.data))
367        self.idx = (self.idx) & (self.res_err_data != 0)
[0766d6d]368
[c4d6900]369    def get_fit_range(self):
[7d0c1a8]370        """
[aa36f96]371        return the range of data.x to fit
[7d0c1a8]372        """
[20d30e9]373        return self.qmin, self.qmax
[7d0c1a8]374     
[1f9f3c8a]375    def residuals(self, fn):
[83195f7]376        """
[aa36f96]377        return the residuals
[1f9f3c8a]378        """
[f72333f]379        if self.smearer != None:
[444c900e]380            fn.set_index(self.idx)
[f72333f]381            # Get necessary data from self.data and set the data for smearing
382            fn.get_data()
383
[1f9f3c8a]384            gn = fn.get_value()
[f72333f]385        else:
[444c900e]386            gn = fn([self.qx_data[self.idx],
387                     self.qy_data[self.idx]])
[83195f7]388        # use only the data point within ROI range
[1f9f3c8a]389        res = (self.data[self.idx] - gn) / self.res_err_data[self.idx]
[0766d6d]390
[1f9f3c8a]391        return res, gn
[0e51519]392       
[7d0c1a8]393    def residuals_deriv(self, model, pars=[]):
[1f9f3c8a]394        """
[aa36f96]395        :return: residuals derivatives .
396       
397        :note: in this case just return empty array
398       
[7d0c1a8]399        """
400        return []
[48882d1]401   
[1f9f3c8a]402   
[4bd557d]403class FitAbort(Exception):
404    """
[aa36f96]405    Exception raise to stop the fit
[4bd557d]406    """
[1ab9dc1]407    #pass
408    #print"Creating fit abort Exception"
[4bd557d]409
410
[70bf68c]411class SansAssembly:
[ca6d914]412    """
[aa36f96]413    Sans Assembly class a class wrapper to be call in optimizer.leastsq method
[ca6d914]414    """
[1f9f3c8a]415    def __init__(self, paramlist, model=None, data=None, fitresult=None,
[ba7dceb]416                 handler=None, curr_thread=None, msg_q=None):
[ca6d914]417        """
[aa36f96]418        :param Model: the model wrapper fro sans -model
419        :param Data: the data wrapper for sans data
420       
[ca6d914]421        """
[e0072082]422        self.model = model
[1f9f3c8a]423        self.data = data
[e0072082]424        self.paramlist = paramlist
[ba7dceb]425        self.msg_q = msg_q
[e0072082]426        self.curr_thread = curr_thread
427        self.handler = handler
428        self.fitresult = fitresult
429        self.res = []
[4b5bd73]430        self.true_res = []
[e0072082]431        self.func_name = "Functor"
[425e49ca]432        self.theory = None
[e0072082]433       
[c4d6900]434    def chisq(self):
[48882d1]435        """
[aa36f96]436        Calculates chi^2
437       
438        :param params: list of parameter values
439       
440        :return: chi^2
441       
[48882d1]442        """
[1f9f3c8a]443        total = 0
[4b5bd73]444        for item in self.true_res:
[1f9f3c8a]445            total += item * item
[4b5bd73]446        if len(self.true_res) == 0:
[4bd557d]447            return None
[1f9f3c8a]448        return total / len(self.true_res)
[20d30e9]449   
[c4d6900]450    def __call__(self, params):
[ca6d914]451        """
[1f9f3c8a]452            Compute residuals
453            :param params: value of parameters to fit
454        """
[4b5bd73]455        #import thread
456        self.model.set_params(self.paramlist, params)
[0766d6d]457        #print "params", params
[5722d66]458        self.true_res, theory = self.data.residuals(self.model.eval)
459        self.theory = copy.deepcopy(theory)
[4b5bd73]460        # check parameters range
461        if self.check_param_range():
462            # if the param value is outside of the bound
463            # just silent return res = inf
464            return self.res
[1f9f3c8a]465        self.res = self.true_res
[ba7dceb]466       
467        if self.fitresult is not None:
[e0072082]468            self.fitresult.set_model(model=self.model)
[444c900e]469            self.fitresult.residuals = self.true_res
[bd7a426]470            self.fitresult.iterations += 1
[444c900e]471            self.fitresult.theory = theory
[ba7dceb]472           
[4b5bd73]473            #fitness = self.chisq(params=params)
[c4d6900]474            fitness = self.chisq()
[511c6810]475            self.fitresult.pvec = params
[90c9cdf]476            self.fitresult.set_fitness(fitness=fitness)
[ba7dceb]477            if self.msg_q is not None:
478                self.msg_q.put(self.fitresult)
479               
480            if self.handler is not None:
481                self.handler.set_result(result=self.fitresult)
482                self.handler.update_fit()
[4b5bd73]483
[1f9f3c8a]484            if self.curr_thread != None:
[d5f0f5e3]485                try:
[078f2f2]486                    self.curr_thread.isquit()
487                except:
[986da97]488                    #msg = "Fitting: Terminated...       Note: Forcing to stop "
489                    #msg += "fitting may cause a 'Functor error message' "
490                    #msg += "being recorded in the log file....."
491                    #self.handler.stop(msg)
[1ab9dc1]492                    raise
[12cd4ec]493         
[48882d1]494        return self.res
495   
[4b5bd73]496    def check_param_range(self):
497        """
498        Check the lower and upper bound of the parameter value
499        and set res to the inf if the value is outside of the
500        range
501        :limitation: the initial values must be within range.
502        """
503
[bdc25e2]504        #time.sleep(0.01)
[4b5bd73]505        is_outofbound = False
506        # loop through the fit parameters
507        for p in self.model.parameterset:
508            param_name = p.get_name()
509            if param_name in self.paramlist:
510               
511                # if the range was defined, check the range
512                if numpy.isfinite(p.range[0]):
513                    if p.value == 0:
514                        # This value works on Scipy
515                        # Do not change numbers below
[1f9f3c8a]516                        value = _SMALLVALUE
[4b5bd73]517                    else:
518                        value = p.value
519                    # For leastsq, it needs a bit step back from the boundary
[1f9f3c8a]520                    val = p.range[0] - value * _SMALLVALUE
521                    if p.value < val:
[4b5bd73]522                        self.res *= 1e+6
523                       
524                        is_outofbound = True
525                        break
526                if numpy.isfinite(p.range[1]):
527                    # This value works on Scipy
528                    # Do not change numbers below
529                    if p.value == 0:
[1f9f3c8a]530                        value = _SMALLVALUE
[4b5bd73]531                    else:
532                        value = p.value
533                    # For leastsq, it needs a bit step back from the boundary
[1f9f3c8a]534                    val = p.range[1] + value * _SMALLVALUE
[4b5bd73]535                    if p.value > val:
536                        self.res *= 1e+6
537                        is_outofbound = True
538                        break
539
540        return is_outofbound
541   
542   
[4c718654]543class FitEngine:
[ee5b04c]544    def __init__(self):
[ca6d914]545        """
[aa36f96]546        Base class for scipy and park fit engine
[ca6d914]547        """
548        #List of parameter names to fit
[b2f25dc5]549        self.param_list = []
[ca6d914]550        #Dictionnary of fitArrange element (fit problems)
[b2f25dc5]551        self.fit_arrange_dict = {}
[06e7c26]552        self.fitter_id = None
[7db52f1]553       
[1f9f3c8a]554    def set_model(self, model, id, pars=[], constraints=[], data=None):
[4c718654]555        """
[c4d6900]556        set a model on a given  in the fit engine.
[aa36f96]557       
558        :param model: sans.models type
[c4d6900]559        :param : is the key of the fitArrange dictionary where model is
[aa36f96]560                saved as a value
561        :param pars: the list of parameters to fit
562        :param constraints: list of
563            tuple (name of parameter, value of parameters)
564            the value of parameter must be a string to constraint 2 different
565            parameters.
566            Example: 
567            we want to fit 2 model M1 and M2 both have parameters A and B.
568            constraints can be:
569             constraints = [(M1.A, M2.B+2), (M1.B= M2.A *5),...,]
570           
571             
572        :note: pars must contains only name of existing model's parameters
573       
[ca6d914]574        """
[fd6b789]575        if model == None:
576            raise ValueError, "AbstractFitEngine: Need to set model to fit"
[393f0f3]577       
[89f3b66]578        new_model = model
[393f0f3]579        if not issubclass(model.__class__, Model):
[1cff677]580            new_model = Model(model, data)
[fd6b789]581       
[89f3b66]582        if len(constraints) > 0:
[fd6b789]583            for constraint in constraints:
584                name, value = constraint
585                try:
[89f3b66]586                    new_model.parameterset[str(name)].set(str(value))
[fd6b789]587                except:
[89f3b66]588                    msg = "Fit Engine: Error occurs when setting the constraint"
[c4d6900]589                    msg += " %s for parameter %s " % (value, name)
[fd6b789]590                    raise ValueError, msg
591               
[89f3b66]592        if len(pars) > 0:
593            temp = []
[fd6b789]594            for item in pars:
595                if item in new_model.model.getParamList():
596                    temp.append(item)
[b2f25dc5]597                    self.param_list.append(item)
[fd6b789]598                else:
599                   
[89f3b66]600                    msg = "wrong parameter %s used" % str(item)
601                    msg += "to set model %s. Choose" % str(new_model.model.name)
602                    msg += "parameter name within %s" % \
603                                str(new_model.model.getParamList())
604                    raise ValueError, msg
[fd6b789]605             
[c4d6900]606            #A fitArrange is already created but contains data_list only at id
607            if self.fit_arrange_dict.has_key(id):
608                self.fit_arrange_dict[id].set_model(new_model)
609                self.fit_arrange_dict[id].pars = pars
[6831a99]610            else:
[c4d6900]611            #no fitArrange object has been create with this id
[48882d1]612                fitproblem = FitArrange()
[fd6b789]613                fitproblem.set_model(new_model)
[89f3b66]614                fitproblem.pars = pars
[c4d6900]615                self.fit_arrange_dict[id] = fitproblem
[7db52f1]616                vals = []
617                for name in pars:
618                    vals.append(new_model.model.getParam(name))
619                self.fit_arrange_dict[id].vals = vals
[d4b0687]620        else:
[6831a99]621            raise ValueError, "park_integration:missing parameters"
[48882d1]622   
[c4d6900]623    def set_data(self, data, id, smearer=None, qmin=None, qmax=None):
[1f9f3c8a]624        """
[aa36f96]625        Receives plottable, creates a list of data to fit,set data
[1f9f3c8a]626        in a FitArrange object and adds that object in a dictionary
[c4d6900]627        with key id.
[aa36f96]628       
629        :param data: data added
[c4d6900]630        :param id: unique key corresponding to a fitArrange object with data
[ca6d914]631        """
[89f3b66]632        if data.__class__.__name__ == 'Data2D':
633            fitdata = FitData2D(sans_data2d=data, data=data.data,
634                                 err_data=data.err_data)
[f8ce013]635        else:
[1f9f3c8a]636            fitdata = FitData1D(x=data.x, y=data.y,
[89f3b66]637                                 dx=data.dx, dy=data.dy, smearer=smearer)
[634ca14]638        fitdata.sans_data = data
[393f0f3]639       
[c4d6900]640        fitdata.set_fit_range(qmin=qmin, qmax=qmax)
641        #A fitArrange is already created but contains model only at id
[1f9f3c8a]642        if id in self.fit_arrange_dict:
[c4d6900]643            self.fit_arrange_dict[id].add_data(fitdata)
[d4b0687]644        else:
[c4d6900]645        #no fitArrange object has been create with this id
[89f3b66]646            fitproblem = FitArrange()
[f8ce013]647            fitproblem.add_data(fitdata)
[1f9f3c8a]648            self.fit_arrange_dict[id] = fitproblem
[20d30e9]649   
[c4d6900]650    def get_model(self, id):
[1f9f3c8a]651        """
[c4d6900]652        :param id: id is key in the dictionary containing the model to return
[aa36f96]653       
[1f9f3c8a]654        :return:  a model at this id or None if no FitArrange element was
[c4d6900]655            created with this id
[d4b0687]656        """
[1f9f3c8a]657        if id in self.fit_arrange_dict:
[c4d6900]658            return self.fit_arrange_dict[id].get_model()
[d4b0687]659        else:
660            return None
661   
[c4d6900]662    def remove_fit_problem(self, id):
663        """remove   fitarrange in id"""
[1f9f3c8a]664        if id in self.fit_arrange_dict:
[c4d6900]665            del self.fit_arrange_dict[id]
[a9e04aa]666           
[c4d6900]667    def select_problem_for_fit(self, id, value):
[a9e04aa]668        """
[c4d6900]669        select a couple of model and data at the id position in dictionary
[aa36f96]670        and set in self.selected value to value
671       
[1f9f3c8a]672        :param value: the value to allow fitting.
[aa36f96]673                can only have the value one or zero
[a9e04aa]674        """
[1f9f3c8a]675        if id in self.fit_arrange_dict:
[c4d6900]676            self.fit_arrange_dict[id].set_to_fit(value)
[eef2e0ed]677             
[c4d6900]678    def get_problem_to_fit(self, id):
[a9e04aa]679        """
[c4d6900]680        return the self.selected value of the fit problem of id
[aa36f96]681       
[c4d6900]682        :param id: the id of the problem
[a9e04aa]683        """
[1f9f3c8a]684        if id in self.fit_arrange_dict:
[c4d6900]685            self.fit_arrange_dict[id].get_to_fit()
[4c718654]686   
[1f9f3c8a]687   
[d4b0687]688class FitArrange:
689    def __init__(self):
690        """
[aa36f96]691        Class FitArrange contains a set of data for a given model
692        to perform the Fit.FitArrange must contain exactly one model
693        and at least one data for the fit to be performed.
694       
695        model: the model selected by the user
696        Ldata: a list of data what the user wants to fit
[d4b0687]697           
698        """
699        self.model = None
[c4d6900]700        self.data_list = []
[89f3b66]701        self.pars = []
[7db52f1]702        self.vals = []
[a9e04aa]703        self.selected = 0
[d4b0687]704       
[89f3b66]705    def set_model(self, model):
[1f9f3c8a]706        """
[aa36f96]707        set_model save a copy of the model
708       
709        :param model: the model being set
[d4b0687]710        """
711        self.model = model
712       
[89f3b66]713    def add_data(self, data):
[1f9f3c8a]714        """
[c4d6900]715        add_data fill a self.data_list with data to fit
[aa36f96]716       
[1f9f3c8a]717        :param data: Data to add in the list
[d4b0687]718        """
[c4d6900]719        if not data in self.data_list:
720            self.data_list.append(data)
[d4b0687]721           
722    def get_model(self):
[aa36f96]723        """
[1f9f3c8a]724        :return: saved model
725        """
726        return self.model
[d4b0687]727     
728    def get_data(self):
[1f9f3c8a]729        """
[c4d6900]730        :return: list of data data_list
[aa36f96]731        """
[1f9f3c8a]732        return self.data_list[0]
[d4b0687]733     
[89f3b66]734    def remove_data(self, data):
[d4b0687]735        """
[aa36f96]736        Remove one element from the list
737       
[c4d6900]738        :param data: Data to remove from data_list
[d4b0687]739        """
[c4d6900]740        if data in self.data_list:
741            self.data_list.remove(data)
[aa36f96]742           
[1f9f3c8a]743    def set_to_fit(self, value=0):
[a9e04aa]744        """
[aa36f96]745        set self.selected to 0 or 1  for other values raise an exception
746       
747        :param value: integer between 0 or 1
[a9e04aa]748        """
[89f3b66]749        self.selected = value
[a9e04aa]750       
751    def get_to_fit(self):
752        """
[aa36f96]753        return self.selected value
[a9e04aa]754        """
755        return self.selected
[444c900e]756   
757   
758IS_MAC = True
759if sys.platform.count("win32") > 0:
760    IS_MAC = False
[1f9f3c8a]761
762
[444c900e]763class FResult(object):
764    """
765    Storing fit result
766    """
767    def __init__(self, model=None, param_list=None, data=None):
768        self.calls = None
[06e7c26]769        self.pars = []
[444c900e]770        self.fitness = None
771        self.chisqr = None
772        self.pvec = []
773        self.cov = []
774        self.info = None
775        self.mesg = None
776        self.success = None
777        self.stderr = None
778        self.residuals = []
779        self.index = []
780        self.parameters = None
781        self.is_mac = IS_MAC
782        self.model = model
783        self.data = data
784        self.theory = []
785        self.param_list = param_list
786        self.iterations = 0
787        self.inputs = []
[06e7c26]788        self.fitter_id = None
[444c900e]789        if self.model is not None and self.data is not None:
790            self.inputs = [(self.model, self.data)]
791     
792    def set_model(self, model):
793        """
794        """
795        self.model = model
796       
797    def set_fitness(self, fitness):
798        """
799        """
800        self.fitness = fitness
801       
802    def __str__(self):
803        """
804        """
805        if self.pvec == None and self.model is None and self.param_list is None:
806            return "No results"
807        n = len(self.model.parameterset)
[bd7a426]808       
[444c900e]809        result_param = zip(xrange(n), self.model.parameterset)
810        msg1 = ["[Iteration #: %s ]" % self.iterations]
811        msg3 = ["=== goodness of fit: %s ===" % (str(self.fitness))]
812        if not self.is_mac:
813            msg2 = ["P%-3d  %s......|.....%s" % \
814                (p[0], p[1], p[1].value)\
815                  for p in result_param if p[1].name in self.param_list]
[1f9f3c8a]816            msg = msg1 + msg3 + msg2
[444c900e]817        else:
818            msg = msg1 + msg3
819        msg = "\n".join(msg)
820        return msg
821   
822    def print_summary(self):
823        """
824        """
[1f9f3c8a]825        print self
Note: See TracBrowser for help on using the repository browser.