source: sasview/park_integration/ScipyFitting.py @ 3a848b2

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 3a848b2 was 90c9cdf, checked in by Gervaise Alina <gervyh@…>, 15 years ago

update chisqr for scipyfit

  • Property mode set to 100644
File size: 5.4 KB
Line 
1"""
2    @organization: ScipyFitting module contains FitArrange , ScipyFit,
3    Parameter classes.All listed classes work together to perform a
4    simple fit with scipy optimizer.
5"""
6
7import numpy 
8from scipy import optimize
9
10from AbstractFitEngine import FitEngine, SansAssembly,FitAbort
11
12class fitresult(object):
13    """
14        Storing fit result
15    """
16    def __init__(self, model=None, paramList=None):
17        self.calls     = None
18        self.fitness   = None
19        self.chisqr    = None
20        self.pvec      = None
21        self.cov       = None
22        self.info      = None
23        self.mesg      = None
24        self.success   = None
25        self.stderr    = None
26        self.parameters = None
27        self.model = model
28        self.paramList = paramList
29     
30    def set_model(self, model):
31        self.model = model
32       
33    def set_fitness(self, fitness):
34        self.fitness = fitness
35       
36    def __str__(self):
37        if self.pvec == None and self.model is None and self.paramList is None:
38            return "No results"
39        n = len(self.model.parameterset)
40
41        result_param = zip(xrange(n), self.model.parameterset)
42        L = ["P%-3d  %s......|.....%s"%(p[0], p[1], p[1].value) for p in result_param if p[1].name in self.paramList ]
43        L.append("=== goodness of fit: %s"%(str(self.fitness)))
44        return "\n".join(L)
45   
46    def print_summary(self):
47        print self   
48
49class ScipyFit(FitEngine):
50    """
51        ScipyFit performs the Fit.This class can be used as follow:
52        #Do the fit SCIPY
53        create an engine: engine = ScipyFit()
54        Use data must be of type plottable
55        Use a sans model
56       
57        Add data with a dictionnary of FitArrangeDict where Uid is a key and data
58        is saved in FitArrange object.
59        engine.set_data(data,Uid)
60       
61        Set model parameter "M1"= model.name add {model.parameter.name:value}.
62        @note: Set_param() if used must always preceded set_model()
63             for the fit to be performed.In case of Scipyfit set_param is called in
64             fit () automatically.
65        engine.set_param( model,"M1", {'A':2,'B':4})
66       
67        Add model with a dictionnary of FitArrangeDict{} where Uid is a key and model
68        is save in FitArrange object.
69        engine.set_model(model,Uid)
70       
71        engine.fit return chisqr,[model.parameter 1,2,..],[[err1....][..err2...]]
72        chisqr1, out1, cov1=engine.fit({model.parameter.name:value},qmin,qmax)
73    """
74    def __init__(self):
75        """
76            Creates a dictionary (self.fitArrangeDict={})of FitArrange elements
77            with Uid as keys
78        """
79        self.fitArrangeDict={}
80        self.paramList=[]
81    #def fit(self, *args, **kw):
82    #    return profile(self._fit, *args, **kw)
83
84    def fit(self, q=None, handler=None, curr_thread=None):
85       
86        fitproblem=[]
87        for id ,fproblem in self.fitArrangeDict.iteritems():
88            if fproblem.get_to_fit()==1:
89                fitproblem.append(fproblem)
90        if len(fitproblem)>1 : 
91            msg = "Scipy can't fit more than a single fit problem at a time."
92            raise RuntimeError, msg
93            return
94        elif len(fitproblem)==0 : 
95            raise RuntimeError, "No Assembly scheduled for Scipy fitting."
96            return
97   
98        listdata=[]
99        model = fitproblem[0].get_model()
100        listdata = fitproblem[0].get_data()
101        # Concatenate dList set (contains one or more data)before fitting
102        data = listdata
103        self.curr_thread= curr_thread
104        result = fitresult(model=model, paramList=self.paramList)
105        if handler is not None:
106            handler.set_result(result=result)
107        #try:
108        functor = SansAssembly(self.paramList, model, data, handler=handler,
109                                fitresult=result,curr_thread= self.curr_thread)
110       
111       
112        out, cov_x, info, mesg, success = optimize.leastsq(functor,
113                                                model.getParams(self.paramList),
114                                                    full_output=1, warning=True)
115       
116        chisqr = functor.chisq(out)
117       
118        if cov_x is not None and numpy.isfinite(cov_x).all():
119            stderr = numpy.sqrt(numpy.diag(cov_x))
120        else:
121            stderr = None
122        if not (numpy.isnan(out).any()) or ( cov_x !=None) :
123                result.fitness = chisqr
124                result.stderr  = stderr
125                result.pvec = out
126                result.success = success
127                print result
128                if q is not  None:
129                    #print "went here"
130                    q.put(result)
131                    #print "get q scipy fit enfine",q.get()
132                    return q
133                return result
134        else: 
135            raise ValueError, "SVD did not converge"+str(success)
136   
137
138
139def profile(fn, *args, **kw):
140    import cProfile, pstats, os
141    global call_result
142    def call():
143        global call_result
144        call_result = fn(*args, **kw)
145    cProfile.runctx('call()', dict(call=call), {}, 'profile.out')
146    stats = pstats.Stats('profile.out')
147    #stats.sort_stats('time')
148    stats.sort_stats('calls')
149    stats.print_stats()
150    os.unlink('profile.out')
151    return call_result
152
153     
Note: See TracBrowser for help on using the repository browser.