source: sasview/park_integration/ScipyFitting.py @ e0072082

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since e0072082 was e0072082, checked in by Gervaise Alina <gervyh@…>, 14 years ago

displaying result on status bar for single fit

  • Property mode set to 100644
File size: 5.3 KB
Line 
1"""
2    @organization: ScipyFitting module contains FitArrange , ScipyFit,
3    Parameter classes.All listed classes work together to perform a
4    simple fit with scipy optimizer.
5"""
6
7import numpy 
8from scipy import optimize
9
10from AbstractFitEngine import FitEngine, SansAssembly,FitAbort
11
12class fitresult(object):
13    """
14        Storing fit result
15    """
16    def __init__(self, model=None, paramList=None):
17        self.calls     = None
18        self.fitness   = None
19        self.chisqr    = None
20        self.pvec      = None
21        self.cov       = None
22        self.info      = None
23        self.mesg      = None
24        self.success   = None
25        self.stderr    = None
26        self.parameters = None
27        self.model = model
28        self.paramList = paramList
29     
30    def set_model(self, model):
31        self.model = model
32       
33    def __str__(self):
34        if self.pvec == None and self.model is None and self.paramList is None:
35            return "No results"
36        n = len(self.model.parameterset)
37
38        result_param = zip(xrange(n), self.model.parameterset)
39        L = ["P%-3d  %s......|.....%s"%(p[0], p[1], p[1].value) for p in result_param if p[1].name in self.paramList ]
40        L.append("=== goodness of fit: %s"%(str(self.fitness)))
41        return "\n".join(L)
42   
43    def print_summary(self):
44        print self   
45
46class ScipyFit(FitEngine):
47    """
48        ScipyFit performs the Fit.This class can be used as follow:
49        #Do the fit SCIPY
50        create an engine: engine = ScipyFit()
51        Use data must be of type plottable
52        Use a sans model
53       
54        Add data with a dictionnary of FitArrangeDict where Uid is a key and data
55        is saved in FitArrange object.
56        engine.set_data(data,Uid)
57       
58        Set model parameter "M1"= model.name add {model.parameter.name:value}.
59        @note: Set_param() if used must always preceded set_model()
60             for the fit to be performed.In case of Scipyfit set_param is called in
61             fit () automatically.
62        engine.set_param( model,"M1", {'A':2,'B':4})
63       
64        Add model with a dictionnary of FitArrangeDict{} where Uid is a key and model
65        is save in FitArrange object.
66        engine.set_model(model,Uid)
67       
68        engine.fit return chisqr,[model.parameter 1,2,..],[[err1....][..err2...]]
69        chisqr1, out1, cov1=engine.fit({model.parameter.name:value},qmin,qmax)
70    """
71    def __init__(self):
72        """
73            Creates a dictionary (self.fitArrangeDict={})of FitArrange elements
74            with Uid as keys
75        """
76        self.fitArrangeDict={}
77        self.paramList=[]
78    #def fit(self, *args, **kw):
79    #    return profile(self._fit, *args, **kw)
80
81    def fit(self, q=None, handler=None, curr_thread=None):
82       
83        fitproblem=[]
84        for id ,fproblem in self.fitArrangeDict.iteritems():
85            if fproblem.get_to_fit()==1:
86                fitproblem.append(fproblem)
87        if len(fitproblem)>1 : 
88            msg = "Scipy can't fit more than a single fit problem at a time."
89            raise RuntimeError, msg
90            return
91        elif len(fitproblem)==0 : 
92            raise RuntimeError, "No Assembly scheduled for Scipy fitting."
93            return
94   
95        listdata=[]
96        model = fitproblem[0].get_model()
97        listdata = fitproblem[0].get_data()
98        # Concatenate dList set (contains one or more data)before fitting
99        data = listdata
100        self.curr_thread= curr_thread
101        result = fitresult(model=model, paramList=self.paramList)
102        if handler is not None:
103            handler.set_result(result=result)
104        #try:
105        functor = SansAssembly(self.paramList, model, data, handler=handler,
106                                fitresult=result,curr_thread= self.curr_thread)
107       
108       
109        out, cov_x, info, mesg, success = optimize.leastsq(functor,
110                                                model.getParams(self.paramList),
111                                                    full_output=1, warning=True)
112       
113        chisqr = functor.chisq(out)
114       
115        if cov_x is not None and numpy.isfinite(cov_x).all():
116            stderr = numpy.sqrt(numpy.diag(cov_x))
117        else:
118            stderr = None
119        if not (numpy.isnan(out).any()) or ( cov_x !=None) :
120                result.fitness = chisqr
121                result.stderr  = stderr
122                result.pvec = out
123                result.success = success
124                print result
125                if q is not  None:
126                    #print "went here"
127                    q.put(result)
128                    #print "get q scipy fit enfine",q.get()
129                    return q
130                return result
131        else: 
132            raise ValueError, "SVD did not converge"+str(success)
133   
134
135
136def profile(fn, *args, **kw):
137    import cProfile, pstats, os
138    global call_result
139    def call():
140        global call_result
141        call_result = fn(*args, **kw)
142    cProfile.runctx('call()', dict(call=call), {}, 'profile.out')
143    stats = pstats.Stats('profile.out')
144    #stats.sort_stats('time')
145    stats.sort_stats('calls')
146    stats.print_stats()
147    os.unlink('profile.out')
148    return call_result
149
150     
Note: See TracBrowser for help on using the repository browser.