source: sasview/test/park_integration/test/utest_fit_line.py @ 386ffe1

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 386ffe1 was 386ffe1, checked in by pkienzle, 9 years ago

Remove the SciPy Levenberg–Marquardt and park fitting engines from the UI

  • Property mode set to 100644
File size: 7.2 KB
Line 
1"""
2    Unit tests for fitting module
3    @author Gervaise Alina
4"""
5import unittest
6import math
7
8from sas.fit.AbstractFitEngine import Model, FitHandler
9from sas.dataloader.loader import Loader
10from sas.fit.Fitting import Fit
11from sas.models.LineModel import LineModel
12from sas.models.Constant import Constant
13
14class testFitModule(unittest.TestCase):
15    """ test fitting """
16
17    def test_bad_pars(self):
18        fitter = Fit('bumps')
19
20        data = Loader().load("testdata_line.txt")
21        data.name = data.filename
22        fitter.set_data(data,1)
23
24        model1  = LineModel()
25        model1.name = "M1"
26        model = Model(model1, data)
27        pars1= ['param1','param2']
28        try:
29            fitter.set_model(model,1,pars1)
30        except ValueError,exc:
31            #print "ValueError was correctly raised: "+str(msg)
32            assert str(exc).startswith('parameter param1')
33        else:
34            raise AssertionError("No error raised for fitting with wrong parameters name to fit")
35
36    def fit_single(self, fitter_name, isdream=False):
37        fitter = Fit(fitter_name)
38
39        data = Loader().load("testdata_line.txt")
40        data.name = data.filename
41        fitter.set_data(data,1)
42
43        # Receives the type of model for the fitting
44        model1  = LineModel()
45        model1.name = "M1"
46        model = Model(model1,data)
47
48        pars1= ['A','B']
49        fitter.set_model(model,1,pars1)
50        fitter.select_problem_for_fit(id=1,value=1)
51        result1, = fitter.fit(handler=FitHandler())
52
53        # The target values were generated from the following statements
54        p,s,fx = result1.pvec, result1.stderr, result1.fitness
55        #print "p0,p1,s0,s1,fx = %g, %g, %g, %g, %g"%(p[0],p[1],s[0],s[1],fx)
56        p0,p1,s0,s1,fx_ = 3.68353, 2.61004, 0.336186, 0.105244, 1.20189
57
58        if isdream:
59            # Dream is not a minimizer: just check that the fit is within
60            # uncertainty
61            self.assertTrue( abs(p[0]-p0) <= s0 )
62            self.assertTrue( abs(p[1]-p1) <= s1 )
63        else:
64            self.assertTrue( abs(p[0]-p0) <= 1e-5 )
65            self.assertTrue( abs(p[1]-p1) <= 1e-5 )
66            self.assertTrue( abs(fx-fx_) <= 1e-5 )
67
68    def fit_bumps(self, alg, **opts):
69        #Importing the Fit module
70        from bumps import fitters
71        fitters.FIT_DEFAULT = alg
72        fitters.FIT_OPTIONS[alg].options.update(opts)
73        fitters.FIT_OPTIONS[alg].options.update(monitors=[])
74        #print "fitting",alg,opts
75        #kprint "options",fitters.FIT_OPTIONS[alg].__dict__
76        self.fit_single('bumps', isdream=(alg=='dream'))
77
78    def test_bumps_de(self):
79        self.fit_bumps('de')
80
81    def test_bumps_dream(self):
82        self.fit_bumps('dream', burn=500, steps=100)
83
84    def test_bumps_amoeba(self):
85        self.fit_bumps('amoeba')
86
87    def test_bumps_newton(self):
88        self.fit_bumps('newton')
89
90    def test_bumps_lm(self):
91        self.fit_bumps('lm')
92
93    def test2(self):
94        """ fit 2 data and 2 model with no constrainst"""
95        #load data
96        l = Loader()
97        data1=l.load("testdata_line.txt")
98        data1.name = data1.filename
99     
100        data2=l.load("testdata_line1.txt")
101        data2.name = data2.filename
102     
103        #Importing the Fit module
104        fitter = Fit('bumps')
105        # Receives the type of model for the fitting
106        model11  = LineModel()
107        model11.name= "M1"
108        model22  = LineModel()
109        model11.name= "M2"
110     
111        model1 = Model(model11,data1)
112        model2 = Model(model22,data2)
113        pars1= ['A','B']
114        fitter.set_data(data1,1)
115        fitter.set_model(model1,1,pars1)
116        fitter.select_problem_for_fit(id=1,value=0)
117        fitter.set_data(data2,2)
118        fitter.set_model(model2,2,pars1)
119        fitter.select_problem_for_fit(id=2,value=0)
120
121        try: result1, = fitter.fit(handler=FitHandler())
122        except RuntimeError,msg:
123            assert str(msg)=="Nothing to fit"
124        else: raise AssertionError,"No error raised for fitting with no model"
125        fitter.select_problem_for_fit(id=1,value=1)
126        fitter.select_problem_for_fit(id=2,value=1)
127        R1,R2 = fitter.fit(handler=FitHandler())
128       
129        self.assertTrue( math.fabs(R1.pvec[0]-4)/3 <= R1.stderr[0] )
130        self.assertTrue( math.fabs(R1.pvec[1]-2.5)/3 <= R1.stderr[1] )
131        self.assertTrue( R1.fitness/(len(data1.x)+len(data2.x)) < 2)
132       
133       
134    def test_constraints(self):
135        """ fit 2 data and 2 model with 1 constrainst"""
136        #load data
137        l = Loader()
138        data1= l.load("testdata_line.txt")
139        data1.name = data1.filename
140        data2= l.load("testdata_cst.txt")
141        data2.name = data2.filename
142       
143        # Receives the type of model for the fitting
144        model11  = LineModel()
145        model11.name= "line"
146        model11.setParam("A", 1.0)
147        model11.setParam("B",1.0)
148       
149        model22  = Constant()
150        model22.name= "cst"
151        model22.setParam("value", 1.0)
152       
153        model1 = Model(model11,data1)
154        model2 = Model(model22,data2)
155        model1.set(A=4)
156        model1.set(B=3)
157        # Constraint the constant value to be equal to parameter B (the real value is 2.5)
158        #model2.set(value='line.B')
159        pars1= ['A','B']
160        pars2= ['value']
161       
162        #Importing the Fit module
163        fitter = Fit('bumps')
164        fitter.set_data(data1,1)
165        fitter.set_model(model1,1,pars1)
166        fitter.set_data(data2,2,smearer=None)
167        fitter.set_model(model2,2,pars2,constraints=[("value","line.B")])
168        fitter.select_problem_for_fit(id=1,value=1)
169        fitter.select_problem_for_fit(id=2,value=1)
170       
171        R1,R2 = fitter.fit(handler=FitHandler())
172        self.assertTrue( math.fabs(R1.pvec[0]-4.0)/3. <= R1.stderr[0])
173        self.assertTrue( math.fabs(R1.pvec[1]-2.5)/3. <= R1.stderr[1])
174        self.assertTrue( R1.fitness/(len(data1.x)+len(data2.x)) < 2)
175       
176       
177    def test4(self):
178        """ fit 2 data concatenates with limited range of x and  one model """
179            #load data
180        l = Loader()
181        data1 = l.load("testdata_line.txt")
182        data1.name = data1.filename
183        data2 = l.load("testdata_line1.txt")
184        data2.name = data2.filename
185
186        # Receives the type of model for the fitting
187        model1  = LineModel()
188        model1.name= "M1"
189        model1.setParam("A", 1.0)
190        model1.setParam("B",1.0)
191        model = Model(model1,data1)
192     
193        pars1= ['A','B']
194        #Importing the Fit module
195
196        fitter = Fit('bumps')
197        fitter.set_data(data1,1,qmin=0, qmax=7)
198        fitter.set_model(model,1,pars1)
199        fitter.set_data(data2,1,qmin=1,qmax=10)
200        fitter.select_problem_for_fit(id=1,value=1)
201        result2, = fitter.fit(handler=FitHandler())
202       
203        self.assert_(result2)
204        self.assertTrue( math.fabs(result2.pvec[0]-4)/3 <= result2.stderr[0] )
205        self.assertTrue( math.fabs(result2.pvec[1]-2.5)/3 <= result2.stderr[1] )
206        self.assertTrue( result2.fitness/len(data1.x) < 2)
207
208
if __name__ == "__main__":
    # Executed as a script: hand control to the unittest runner, which
    # discovers and runs every test in this module.
    unittest.main()
211   
Note: See TracBrowser for help on using the repository browser.