source: sasview/park_integration/ScipyFitting.py @ 07c4599

ESS_GUI ESS_GUI_Docs ESS_GUI_batch_fitting ESS_GUI_bumps_abstraction ESS_GUI_iss1116 ESS_GUI_iss879 ESS_GUI_iss959 ESS_GUI_opencl ESS_GUI_ordering ESS_GUI_sync_sascalc costrafo411 magnetic_scatt release-4.1.1 release-4.1.2 release-4.2.2 release_4.0.1 ticket-1009 ticket-1094-headless ticket-1242-2d-resolution ticket-1243 ticket-1249 ticket885 unittest-saveload
Last change on this file since 07c4599 was c4d6900, checked in by Gervaise Alina <gervyh@…>, 14 years ago

working on pylint

  • Property mode set to 100644
File size: 5.5 KB
Line 
1
2
3"""
ScipyFitting module contains the fitresult and ScipyFit classes. Together
with the FitEngine machinery from sans.fit.AbstractFitEngine, they perform
a simple fit with the scipy optimizer.
7"""
8
9import numpy 
10from scipy import optimize
11
12from sans.fit.AbstractFitEngine import FitEngine
13from sans.fit.AbstractFitEngine import SansAssembly
14#from sans.fit.AbstractFitEngine import FitAbort
15
class fitresult(object):
    """
    Storing fit result.

    Plain container for the outcome of a fit. ScipyFit.fit in this module
    fills in ``fitness``, ``stderr``, ``pvec`` and ``success``; every other
    field starts out as None.
    """
    def __init__(self, model=None, param_list=None):
        """
        :param model: the model being fitted; ``__str__`` iterates its
            ``parameterset`` (items exposing ``name`` and ``value``).
        :param param_list: names of the parameters that are fitted.
        """
        self.calls = None
        self.fitness = None     # goodness of fit (chi^2 from the fit engine)
        self.chisqr = None
        self.pvec = None        # optimized parameter vector
        self.cov = None
        self.info = None
        self.mesg = None
        self.success = None     # status value reported by the optimizer
        self.stderr = None      # standard errors, or None if unavailable
        self.parameters = None
        self.model = model
        self.param_list = param_list

    def set_model(self, model):
        """
        Set the model whose parameters are reported by ``__str__``.
        """
        self.model = model

    def set_fitness(self, fitness):
        """
        Set the goodness-of-fit value reported by ``__str__``.
        """
        self.fitness = fitness

    def __str__(self):
        """
        Return one line per fitted parameter followed by the goodness of fit.

        Each parameter line has the form ``P<index>  <param>......|.....<value>``
        where ``index`` is the position in ``model.parameterset``. Only
        parameters whose ``name`` appears in ``param_list`` are listed.

        :return: the formatted summary, or "No results" when there is no
            model or no parameter list to report on.
        """
        # Parameter names and values come from the model, so without a model
        # or a parameter list there is nothing to print. (The previous check
        # used ``self.pvec == None``, which is an elementwise comparison when
        # pvec is a numpy array and made the truth test raise.)
        if self.model is None or self.param_list is None:
            return "No results"
        msg = ["P%-3d  %s......|.....%s" % (index, param, param.value)
               for index, param in enumerate(self.model.parameterset)
               if param.name in self.param_list]
        msg.append("=== goodness of fit: %s" % (str(self.fitness)))
        return "\n".join(msg)

    def print_summary(self):
        """
        Print the summary produced by ``__str__``.
        """
        # Parenthesized single argument: same output on Python 2 and 3.
        print(self)
61
class ScipyFit(FitEngine):
    """
    ScipyFit performs the fit with scipy.optimize.leastsq.

    Usage (set_data, set_model and set_param are provided by FitEngine)::

        # Create an engine.
        engine = ScipyFit()
        # Data must be of type plottable. It is stored in a FitArrange
        # object kept in a dictionary with Uid as the key.
        engine.set_data(data, Uid)
        # Add a sans model; it is stored in the FitArrange object for Uid.
        engine.set_model(model, Uid)
        # Set model parameter "M1" = model.name and {model.parameter.name: value}.
        engine.set_param(model, "M1", {'A': 2, 'B': 4})
        # Run the fit: returns a fitresult (or the queue q, if one is given).
        result = engine.fit()

    :note: set_param(), if used, must always be preceded by set_model()
        for the fit to be performed.

    Only a single fit problem can be fitted at a time.
    """
    def __init__(self):
        """
        Creates a dictionary (self.fit_arrange_dict={}) of FitArrange elements
        with Uid as keys.
        """
        FitEngine.__init__(self)
        self.fit_arrange_dict = {}
        # Names of the parameters to optimize; passed to model.get_params()
        # and to SansAssembly.
        self.param_list = []
        self.curr_thread = None

    def _get_fit_problem(self):
        """
        Return the single fit problem flagged for fitting (get_to_fit() == 1).

        :raise RuntimeError: if more than one, or none, is flagged.
        """
        fitproblem = [fproblem for fproblem in self.fit_arrange_dict.values()
                      if fproblem.get_to_fit() == 1]
        if len(fitproblem) > 1:
            msg = "Scipy can't fit more than a single fit problem at a time."
            raise RuntimeError(msg)
        if not fitproblem:
            raise RuntimeError("No Assembly scheduled for Scipy fitting.")
        return fitproblem[0]

    def fit(self, q=None, handler=None, curr_thread=None):
        """
        Run a least-squares fit of the single scheduled fit problem.

        :param q: optional queue; when given, the result is put on it and
            the queue itself is returned.
        :param handler: optional handler; it receives the (not yet filled)
            result through ``set_result`` before fitting starts and is
            passed on to SansAssembly.
        :param curr_thread: forwarded to SansAssembly and stored on
            ``self.curr_thread``.
        :return: the populated fitresult, or ``q`` when a queue is given.
        :raise RuntimeError: if zero or more than one fit problem is scheduled.
        :raise ValueError: if the optimized parameters contain NaN and no
            covariance matrix was returned.
        """
        fitproblem = self._get_fit_problem()
        model = fitproblem.get_model()
        data = fitproblem.get_data()
        self.curr_thread = curr_thread
        result = fitresult(model=model, param_list=self.param_list)
        if handler is not None:
            handler.set_result(result=result)
        functor = SansAssembly(self.param_list, model, data, handler=handler,
                               fitresult=result, curr_thread=self.curr_thread)
        # full_output=1 makes leastsq return (x, cov_x, infodict, mesg, ier).
        # The legacy ``warning=True`` keyword is no longer accepted by SciPy;
        # True was its default, so omitting it keeps the old behaviour.
        out, cov_x, _, _, success = optimize.leastsq(
            functor, model.get_params(self.param_list), full_output=1)

        chisqr = functor.chisq()
        # Standard errors are only reported for a fully finite covariance.
        if cov_x is not None and numpy.isfinite(cov_x).all():
            stderr = numpy.sqrt(numpy.diag(cov_x))
        else:
            stderr = None
        # Identity test on purpose: ``cov_x != None`` on an ndarray compares
        # elementwise, and the resulting array made the truth test raise.
        if numpy.isnan(out).any() and cov_x is None:
            raise ValueError("SVD did not converge" + str(success))

        result.fitness = chisqr
        result.stderr = stderr
        result.pvec = out
        result.success = success
        if q is not None:
            q.put(result)
            return q
        return result
153   
154
155
156#def profile(fn, *args, **kw):
157#    import cProfile, pstats, os
158#    global call_result
159#   def call():
160#        global call_result
161#        call_result = fn(*args, **kw)
162#    cProfile.runctx('call()', dict(call=call), {}, 'profile.out')
163#    stats = pstats.Stats('profile.out')
164#    stats.sort_stats('time')
165#    stats.sort_stats('calls')
166#    stats.print_stats()
167#    os.unlink('profile.out')
168#    return call_result
169
170     
Note: See TracBrowser for help on using the repository browser.