source: sasview/park_integration/ScipyFitting.py @ 3a848b2

Branches: ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Last change on this file since 3a848b2 was 90c9cdf, checked in by Gervaise Alina <gervyh@…>, 15 years ago

update chisqr for scipyfit

  • Property mode set to 100644
File size: 5.4 KB
RevLine 
"""
    @organization: The ScipyFitting module contains the fitresult and
    ScipyFit classes (plus a profile() debugging helper). Together they
    perform a simple fit with the scipy least-squares optimizer.
"""
[61cb28d]6
[88b5e83]7import numpy 
[7705306]8from scipy import optimize
9
[342d9197]10from AbstractFitEngine import FitEngine, SansAssembly,FitAbort
[61cb28d]11
class fitresult(object):
    """
        Container for the outcome of a single scipy fit.

        ScipyFit.fit creates one of these and fills in fitness, stderr,
        pvec and success once the optimizer returns. All other attributes
        start out as None.
    """
    def __init__(self, model=None, paramList=None):
        """
            @param model: the model being fitted (its ``parameterset`` is
                read when the result is printed)
            @param paramList: names of the parameters varied by the fit
        """
        self.calls = None
        # chi^2 of the best fit (assigned by ScipyFit.fit).
        self.fitness = None
        self.chisqr = None
        # Best-fit parameter vector returned by the optimizer.
        self.pvec = None
        self.cov = None
        self.info = None
        self.mesg = None
        # Status flag returned by the optimizer.
        self.success = None
        # Standard errors of the parameters, or None if not available.
        self.stderr = None
        self.parameters = None
        self.model = model
        self.paramList = paramList

    def set_model(self, model):
        """Replace the model this result refers to."""
        self.model = model

    def set_fitness(self, fitness):
        """Store the goodness of fit (chi^2) value."""
        self.fitness = fitness

    def __str__(self):
        """
            Return one line per fitted parameter, then the goodness of fit.

            Only parameters whose name appears in ``paramList`` are listed;
            the index shown is the parameter's position in
            ``model.parameterset``. Returns "No results" when there is no
            model or no parameter list to report on.
        """
        # Both are needed below. ``pvec`` is deliberately not compared with
        # ``==``: after a fit it is an ndarray, and ``array == None`` is
        # elementwise with an ambiguous truth value.
        if self.model is None or self.paramList is None:
            return "No results"
        lines = ["P%-3d  %s......|.....%s" % (index, param, param.value)
                 for index, param in enumerate(self.model.parameterset)
                 if param.name in self.paramList]
        lines.append("=== goodness of fit: %s" % (str(self.fitness)))
        return "\n".join(lines)

    def print_summary(self):
        """Print the textual summary produced by __str__."""
        print(self)

class ScipyFit(FitEngine):
    """
        Fit engine that drives scipy.optimize.leastsq.

        Typical use::

            engine = ScipyFit()
            engine.set_data(data, Uid)         # data must be a plottable
            engine.set_param(model, "M1", {'A': 2, 'B': 4})
            engine.set_model(model, Uid)       # a sans model
            result = engine.fit()

        Data and models are stored as FitArrange objects in
        ``self.fitArrangeDict``, keyed by Uid.

        @note: NOTE(review): the original docs state that set_param() "must
            always preceded set_model()"; the example above follows that
            order. Confirm the required ordering against FitEngine.
        @note: Only one fit problem (an entry whose get_to_fit() == 1) can
            be fitted at a time.
    """
    def __init__(self):
        """
            Create an engine with no fit problems and no fit parameters.
        """
        # FitArrange objects keyed by Uid.
        self.fitArrangeDict = {}
        # Names of the model parameters varied by the fit.
        self.paramList = []

    def fit(self, q=None, handler=None, curr_thread=None):
        """
            Run a least-squares fit of the single scheduled fit problem.

            @param q: optional queue. If given, the result is put on it and
                ``q`` itself is returned.
            @param handler: optional handler. It receives the result object
                via set_result() before fitting starts and is passed on to
                SansAssembly.
            @param curr_thread: passed to SansAssembly and kept on
                ``self.curr_thread``.
            @return: the fitresult, or ``q`` when a queue was supplied.
            @raise RuntimeError: if zero, or more than one, fit problem is
                scheduled.
            @raise ValueError: if the optimizer returned NaN parameters and
                no covariance matrix.
        """
        fitproblem = [fproblem for fproblem in self.fitArrangeDict.values()
                      if fproblem.get_to_fit() == 1]
        if len(fitproblem) > 1:
            msg = "Scipy can't fit more than a single fit problem at a time."
            raise RuntimeError(msg)
        elif not fitproblem:
            raise RuntimeError("No Assembly scheduled for Scipy fitting.")

        model = fitproblem[0].get_model()
        data = fitproblem[0].get_data()
        self.curr_thread = curr_thread
        result = fitresult(model=model, paramList=self.paramList)
        if handler is not None:
            handler.set_result(result=result)
        functor = SansAssembly(self.paramList, model, data, handler=handler,
                               fitresult=result, curr_thread=self.curr_thread)

        # full_output=1 makes leastsq also return cov_x, infodict, mesg and
        # its integer status flag. The old ``warning=True`` keyword was
        # leastsq's default and is rejected by current SciPy.
        out, cov_x, info, mesg, success = optimize.leastsq(
            functor, model.getParams(self.paramList), full_output=1)

        chisqr = functor.chisq(out)

        # leastsq returns cov_x = None when it hits a singular matrix.
        if cov_x is not None and numpy.isfinite(cov_x).all():
            stderr = numpy.sqrt(numpy.diag(cov_x))
        else:
            stderr = None

        # Identity test: ``cov_x != None`` is elementwise on an ndarray and
        # cannot be used as a condition.
        if not numpy.isnan(out).any() or cov_x is not None:
            result.fitness = chisqr
            result.stderr = stderr
            result.pvec = out
            result.success = success
            if q is not None:
                q.put(result)
                return q
            return result
        raise ValueError("SVD did not converge" + str(success))


def profile(fn, *args, **kw):
    """
        Call ``fn(*args, **kw)`` under cProfile and return its result.

        Prints the profiling statistics, sorted by call count, to stdout.
        Statistics are collected in memory, so this helper does not create,
        overwrite or delete a ``profile.out`` file in the working directory.
        If ``fn`` raises, the exception propagates and no statistics are
        printed.
    """
    import cProfile
    import pstats
    profiler = cProfile.Profile()
    # runcall disables the profiler even when fn raises.
    call_result = profiler.runcall(fn, *args, **kw)
    stats = pstats.Stats(profiler)
    stats.sort_stats('calls')
    stats.print_stats()
    return call_result
Note: See TracBrowser for help on using the repository browser.