source: sasview/park_integration/ScipyFitting.py @ 89f3b66

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 89f3b66 was 89f3b66, checked in by Gervaise Alina <gervyh@…>, 14 years ago

working pylint

  • Property mode set to 100644
File size: 5.4 KB
Line 
1
2
3"""
4ScipyFitting module contains FitArrange , ScipyFit,
5Parameter classes.All listed classes work together to perform a
6simple fit with scipy optimizer.
7"""
8
9import numpy 
10from scipy import optimize
11
12from AbstractFitEngine import FitEngine, SansAssembly, FitAbort
13
14class fitresult(object):
15    """
16    Storing fit result
17    """
18    def __init__(self, model=None, paramList=None):
19        self.calls = None
20        self.fitness = None
21        self.chisqr = None
22        self.pvec = None
23        self.cov = None
24        self.info = None
25        self.mesg = None
26        self.success = None
27        self.stderr = None
28        self.parameters = None
29        self.model = model
30        self.paramList = paramList
31     
32    def set_model(self, model):
33        """
34        """
35        self.model = model
36       
37    def set_fitness(self, fitness):
38        """
39        """
40        self.fitness = fitness
41       
42    def __str__(self):
43        """
44        """
45        if self.pvec == None and self.model is None and self.paramList is None:
46            return "No results"
47        n = len(self.model.parameterset)
48
49        result_param = zip(xrange(n), self.model.parameterset)
50        L = ["P%-3d  %s......|.....%s"%(p[0], p[1], p[1].value)\
51              for p in result_param if p[1].name in self.paramList]
52        L.append("=== goodness of fit: %s" % (str(self.fitness)))
53        return "\n".join(L)
54   
55    def print_summary(self):
56        """
57        """
58        print self   
59
60class ScipyFit(FitEngine):
61    """
62    ScipyFit performs the Fit.This class can be used as follow:
63    #Do the fit SCIPY
64    create an engine: engine = ScipyFit()
65    Use data must be of type plottable
66    Use a sans model
67   
68    Add data with a dictionnary of FitArrangeDict where Uid is a key and data
69    is saved in FitArrange object.
70    engine.set_data(data,Uid)
71   
72    Set model parameter "M1"= model.name add {model.parameter.name:value}.
73   
74    :note: Set_param() if used must always preceded set_model()
75         for the fit to be performed.In case of Scipyfit set_param is called in
76         fit () automatically.
77   
78    engine.set_param( model,"M1", {'A':2,'B':4})
79   
80    Add model with a dictionnary of FitArrangeDict{} where Uid is a key and model
81    is save in FitArrange object.
82    engine.set_model(model,Uid)
83   
84    engine.fit return chisqr,[model.parameter 1,2,..],[[err1....][..err2...]]
85    chisqr1, out1, cov1=engine.fit({model.parameter.name:value},qmin,qmax)
86    """
87    def __init__(self):
88        """
89        Creates a dictionary (self.fitArrangeDict={})of FitArrange elements
90        with Uid as keys
91        """
92        self.fitArrangeDict = {}
93        self.paramList = []
94    #def fit(self, *args, **kw):
95    #    return profile(self._fit, *args, **kw)
96
97    def fit(self, q=None, handler=None, curr_thread=None):
98        """
99        """
100        fitproblem = []
101        for id, fproblem in self.fitArrangeDict.iteritems():
102            if fproblem.get_to_fit() == 1:
103                fitproblem.append(fproblem)
104        if len(fitproblem) > 1 : 
105            msg = "Scipy can't fit more than a single fit problem at a time."
106            raise RuntimeError, msg
107            return
108        elif len(fitproblem) == 0 : 
109            raise RuntimeError, "No Assembly scheduled for Scipy fitting."
110            return
111   
112        listdata = []
113        model = fitproblem[0].get_model()
114        listdata = fitproblem[0].get_data()
115        # Concatenate dList set (contains one or more data)before fitting
116        data = listdata
117        self.curr_thread = curr_thread
118        result = fitresult(model=model, paramList=self.paramList)
119        if handler is not None:
120            handler.set_result(result=result)
121        #try:
122        functor = SansAssembly(self.paramList, model, data, handler=handler,
123                         fitresult=result, curr_thread= self.curr_thread)
124       
125       
126        out, cov_x, info, mesg, success = optimize.leastsq(functor,
127                                                model.getParams(self.paramList),
128                                                    full_output=1, warning=True)
129       
130        chisqr = functor.chisq(out)
131       
132        if cov_x is not None and numpy.isfinite(cov_x).all():
133            stderr = numpy.sqrt(numpy.diag(cov_x))
134        else:
135            stderr = None
136        if not (numpy.isnan(out).any()) or ( cov_x !=None) :
137                result.fitness = chisqr
138                result.stderr  = stderr
139                result.pvec = out
140                result.success = success
141                #print result
142                if q is not  None:
143                    #print "went here"
144                    q.put(result)
145                    #print "get q scipy fit enfine",q.get()
146                    return q
147                return result
148        else: 
149            raise ValueError, "SVD did not converge" + str(success)
150   
151
152
153def profile(fn, *args, **kw):
154    import cProfile, pstats, os
155    global call_result
156    def call():
157        global call_result
158        call_result = fn(*args, **kw)
159    cProfile.runctx('call()', dict(call=call), {}, 'profile.out')
160    stats = pstats.Stats('profile.out')
161    #stats.sort_stats('time')
162    stats.sort_stats('calls')
163    stats.print_stats()
164    os.unlink('profile.out')
165    return call_result
166
167     
Note: See TracBrowser for help on using the repository browser.