source: sasview/park_integration/ScipyFitting.py @ 162d7a2

ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Last change on this file since 162d7a2 was aa36f96, checked in by Gervaise Alina <gervyh@…>, 15 years ago

working on documentation

  • Property mode set to 100644
File size: 5.4 KB
RevLine 
[aa36f96]1
2
"""
ScipyFitting module contains the fitresult and ScipyFit classes, which work
together to perform a simple fit with the scipy optimizer.

ScipyFit fits the single FitArrange scheduled for fitting (its model and
data) with scipy.optimize.leastsq and stores the outcome in a fitresult.
A small profile() helper for debugging is also provided.
"""
[61cb28d]8
[88b5e83]9import numpy 
[7705306]10from scipy import optimize
11
[342d9197]12from AbstractFitEngine import FitEngine, SansAssembly,FitAbort
[61cb28d]13
class fitresult(object):
    """
    Storing fit result.

    Holds the model that was fitted, the names of the fitted parameters and
    the values reported by the fit engine. All result attributes start as
    None; ScipyFit.fit fills in fitness, stderr, pvec and success.
    """
    def __init__(self, model=None, paramList=None):
        """
        :param model: the fitted model; __str__ reads its ``parameterset``
            (each entry needs ``name`` and ``value`` attributes)
        :param paramList: names of the parameters that were fitted; only
            these are listed by __str__
        """
        self.calls = None
        self.fitness = None
        self.chisqr = None
        self.pvec = None
        self.cov = None
        self.info = None
        self.mesg = None
        self.success = None
        self.stderr = None
        self.parameters = None
        self.model = model
        self.paramList = paramList

    def set_model(self, model):
        """
        Replace the model whose parameters are reported by __str__.
        """
        self.model = model

    def set_fitness(self, fitness):
        """
        Store the goodness of fit reported on the last summary line.
        """
        self.fitness = fitness

    def __str__(self):
        """
        Return one line per fitted parameter (its index in the model's
        parameterset, the parameter and its value) followed by the goodness
        of fit, or "No results" when there is nothing to report.
        """
        # Both are required to build the listing: the model supplies the
        # parameters and paramList filters them. pvec is not consulted, so
        # no (ambiguous) ``==`` comparison against a numpy array happens.
        if self.model is None or self.paramList is None:
            return "No results"
        lines = ["P%-3d  %s......|.....%s" % (index, param, param.value)
                 for index, param in enumerate(self.model.parameterset)
                 if param.name in self.paramList]
        lines.append("=== goodness of fit: %s" % (str(self.fitness)))
        return "\n".join(lines)

    def print_summary(self):
        """
        Print the summary produced by __str__.
        """
        print(self)
[88b5e83]58
class ScipyFit(FitEngine):
    """
    ScipyFit performs the fit. This class can be used as follows::

        # Do the fit with SCIPY
        engine = ScipyFit()

        # Data must be of type plottable; the model must be a sans model.
        # Data and model are kept in a FitArrange object stored in the
        # fitArrangeDict dictionary with Uid as the key.
        engine.set_data(data, Uid)

        # Set model parameter "M1" = model.name, {model.parameter.name: value}
        engine.set_param(model, "M1", {'A': 2, 'B': 4})

        engine.set_model(model, Uid)

        # fit() returns a fitresult, or, when a queue q is passed, puts the
        # fitresult on q and returns q.
        result = engine.fit()

    set_data, set_param and set_model are not defined in this class; they
    come from FitEngine.

    :note: NOTE(review): an earlier note stated that set_param(), if used,
        must always precede set_model() (wording was ambiguous) — confirm
        the required order against FitEngine.
    """
    def __init__(self):
        """
        Create an empty dictionary (self.fitArrangeDict) of FitArrange
        elements keyed by Uid, and an empty list of parameter names to fit.
        """
        self.fitArrangeDict = {}
        self.paramList = []

    def fit(self, q=None, handler=None, curr_thread=None):
        """
        Fit the single FitArrange scheduled for fitting using
        scipy.optimize.leastsq.

        :param q: optional queue-like object; when given, the fitresult is
            put on it with ``q.put`` and ``q`` itself is returned
        :param handler: optional object notified via
            ``handler.set_result(result=...)`` and passed to SansAssembly
        :param curr_thread: stored on ``self.curr_thread`` and passed to
            SansAssembly
        :return: the fitresult, or ``q`` when a queue was supplied
        :raise RuntimeError: if no FitArrange, or more than one, is
            scheduled for fitting (``get_to_fit() == 1``)
        :raise ValueError: if the optimized parameters contain NaN and no
            covariance matrix was produced
        """
        fitproblem = [fproblem for fproblem in self.fitArrangeDict.values()
                      if fproblem.get_to_fit() == 1]
        if len(fitproblem) > 1:
            msg = "Scipy can't fit more than a single fit problem at a time."
            raise RuntimeError(msg)
        elif not fitproblem:
            raise RuntimeError("No Assembly scheduled for Scipy fitting.")

        model = fitproblem[0].get_model()
        data = fitproblem[0].get_data()
        self.curr_thread = curr_thread
        result = fitresult(model=model, paramList=self.paramList)
        if handler is not None:
            handler.set_result(result=result)
        functor = SansAssembly(self.paramList, model, data, handler=handler,
                               fitresult=result, curr_thread=self.curr_thread)

        # leastsq's old ``warning`` argument (default True) no longer exists
        # in SciPy; passing it raises TypeError, so rely on the default.
        out, cov_x, _info, _mesg, success = optimize.leastsq(
            functor, model.getParams(self.paramList), full_output=1)

        chisqr = functor.chisq(out)

        if cov_x is not None and numpy.isfinite(cov_x).all():
            stderr = numpy.sqrt(numpy.diag(cov_x))
        else:
            stderr = None

        # Reject only when the parameters contain NaN AND no covariance was
        # returned. ``is None`` is required: ``!=`` on an ndarray compares
        # elementwise and cannot be used as a truth value.
        if numpy.isnan(out).any() and cov_x is None:
            raise ValueError("SVD did not converge" + str(success))

        result.fitness = chisqr
        result.stderr = stderr
        result.pvec = out
        result.success = success
        if q is not None:
            q.put(result)
            return q
        return result
[e0072082]149   
[d8a2e31]150
151
def profile(fn, *args, **kw):
    """
    Run ``fn(*args, **kw)`` under cProfile, print the statistics sorted by
    number of calls, and return whatever ``fn`` returned.

    Exceptions raised by ``fn`` propagate unchanged (no statistics are
    printed in that case). No global state is touched and no file is
    written to the working directory.
    """
    import cProfile
    import pstats
    profiler = cProfile.Profile()
    # runcall hands back fn's return value directly and disables the
    # profiler in a finally clause, so no global or temp file is needed.
    result = profiler.runcall(fn, *args, **kw)
    stats = pstats.Stats(profiler)
    stats.sort_stats('calls')
    stats.print_stats()
    return result
165
[48882d1]166     
Note: See TracBrowser for help on using the repository browser.