source: sasview/test/park_integration/test/utest_fit_line.py @ 35ec279

ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Last change on this file since 35ec279 was 35ec279, checked in by krzywon, 9 years ago

Completed the SANS to SAS conversion on the tests. src.sas is left.

  • Property mode set to 100644
File size: 9.1 KB
Line 
1"""
2    Unit tests for fitting module
3    @author Gervaise Alina
4"""
5import unittest
6import math
7
8from sas.fit.AbstractFitEngine import Model, FitHandler
9from sas.dataloader.loader import Loader
10from sas.fit.Fitting import Fit
11from sas.models.LineModel import LineModel
12from sas.models.Constant import Constant
13
14class testFitModule(unittest.TestCase):
15    """ test fitting """
16
17    def test_bad_pars(self):
18        fitter = Fit('bumps')
19
20        data = Loader().load("testdata_line.txt")
21        data.name = data.filename
22        fitter.set_data(data,1)
23
24        model1  = LineModel()
25        model1.name = "M1"
26        model = Model(model1, data)
27        pars1= ['param1','param2']
28        try:
29            fitter.set_model(model,1,pars1)
30        except ValueError,exc:
31            #print "ValueError was correctly raised: "+str(msg)
32            assert str(exc).startswith('parameter param1')
33        else:
34            raise AssertionError("No error raised for scipy fitting with wrong parameters name to fit")
35
36    def fit_single(self, fitter_name, isdream=False):
37        fitter = Fit(fitter_name)
38
39        data = Loader().load("testdata_line.txt")
40        data.name = data.filename
41        fitter.set_data(data,1)
42
43        # Receives the type of model for the fitting
44        model1  = LineModel()
45        model1.name = "M1"
46        model = Model(model1,data)
47        #fit with scipy test
48
49        pars1= ['A','B']
50        fitter.set_model(model,1,pars1)
51        fitter.select_problem_for_fit(id=1,value=1)
52        result1, = fitter.fit(handler=FitHandler())
53
54        # The target values were generated from the following statements
55        p,s,fx = result1.pvec, result1.stderr, result1.fitness
56        #print "p0,p1,s0,s1,fx = %g, %g, %g, %g, %g"%(p[0],p[1],s[0],s[1],fx)
57        p0,p1,s0,s1,fx_ = 3.68353, 2.61004, 0.336186, 0.105244, 1.20189
58
59        if isdream:
60            # Dream is not a minimizer: just check that the fit is within
61            # uncertainty
62            self.assertTrue( abs(p[0]-p0) <= s0 )
63            self.assertTrue( abs(p[1]-p1) <= s1 )
64        else:
65            self.assertTrue( abs(p[0]-p0) <= 1e-5 )
66            self.assertTrue( abs(p[1]-p1) <= 1e-5 )
67            self.assertTrue( abs(fx-fx_) <= 1e-5 )
68
69    def fit_bumps(self, alg, **opts):
70        #Importing the Fit module
71        from bumps import fitters
72        fitters.FIT_DEFAULT = alg
73        fitters.FIT_OPTIONS[alg].options.update(opts)
74        fitters.FIT_OPTIONS[alg].options.update(monitors=[])
75        #print "fitting",alg,opts
76        #kprint "options",fitters.FIT_OPTIONS[alg].__dict__
77        self.fit_single('bumps', isdream=(alg=='dream'))
78
79    def test_bumps_de(self):
80        self.fit_bumps('de')
81
82    def test_bumps_dream(self):
83        self.fit_bumps('dream', burn=500, steps=100)
84
85    def test_bumps_amoeba(self):
86        self.fit_bumps('amoeba')
87
88    def test_bumps_newton(self):
89        self.fit_bumps('newton')
90
91    def test_bumps_lm(self):
92        self.fit_bumps('lm')
93
94    def test_scipy(self):
95        #print "fitting scipy"
96        self.fit_single('scipy')
97
98    def test_park(self):
99        #print "fitting park"
100        self.fit_single('park')
101
102       
103    def test2(self):
104        """ fit 2 data and 2 model with no constrainst"""
105        #load data
106        l = Loader()
107        data1=l.load("testdata_line.txt")
108        data1.name = data1.filename
109     
110        data2=l.load("testdata_line1.txt")
111        data2.name = data2.filename
112     
113        #Importing the Fit module
114        fitter = Fit('scipy')
115        # Receives the type of model for the fitting
116        model11  = LineModel()
117        model11.name= "M1"
118        model22  = LineModel()
119        model11.name= "M2"
120     
121        model1 = Model(model11,data1)
122        model2 = Model(model22,data2)
123        #fit with scipy test
124        pars1= ['A','B']
125        fitter.set_data(data1,1)
126        fitter.set_model(model1,1,pars1)
127        fitter.select_problem_for_fit(id=1,value=0)
128        fitter.set_data(data2,2)
129        fitter.set_model(model2,2,pars1)
130        fitter.select_problem_for_fit(id=2,value=0)
131       
132        try: result1, = fitter.fit(handler=FitHandler())
133        except RuntimeError,msg:
134           assert str(msg)=="No Assembly scheduled for Scipy fitting."
135        else: raise AssertionError,"No error raised for scipy fitting with no model"
136        fitter.select_problem_for_fit(id=1,value=1)
137        fitter.select_problem_for_fit(id=2,value=1)
138        try: result1, = fitter.fit(handler=FitHandler())
139        except RuntimeError,msg:
140           assert str(msg)=="Scipy can't fit more than a single fit problem at a time."
141        else: raise AssertionError,"No error raised for scipy fitting with more than 2 models"
142
143        #fit with park test
144        fitter = Fit('park')
145        fitter.set_data(data1,1)
146        fitter.set_model(model1,1,pars1)
147        fitter.set_data(data2,2)
148        fitter.set_model(model2,2,pars1)
149        fitter.select_problem_for_fit(id=1,value=1)
150        fitter.select_problem_for_fit(id=2,value=1)
151        R1,R2 = fitter.fit(handler=FitHandler())
152       
153        self.assertTrue( math.fabs(R1.pvec[0]-4)/3 <= R1.stderr[0] )
154        self.assertTrue( math.fabs(R1.pvec[1]-2.5)/3 <= R1.stderr[1] )
155        self.assertTrue( R1.fitness/(len(data1.x)+len(data2.x)) < 2)
156       
157       
158    def test3(self):
159        """ fit 2 data and 2 model with 1 constrainst"""
160        #load data
161        l = Loader()
162        data1= l.load("testdata_line.txt")
163        data1.name = data1.filename
164        data2= l.load("testdata_cst.txt")
165        data2.name = data2.filename
166       
167        # Receives the type of model for the fitting
168        model11  = LineModel()
169        model11.name= "line"
170        model11.setParam("A", 1.0)
171        model11.setParam("B",1.0)
172       
173        model22  = Constant()
174        model22.name= "cst"
175        model22.setParam("value", 1.0)
176       
177        model1 = Model(model11,data1)
178        model2 = Model(model22,data2)
179        model1.set(A=4)
180        model1.set(B=3)
181        # Constraint the constant value to be equal to parameter B (the real value is 2.5)
182        model2.set(value='line.B')
183        #fit with scipy test
184        pars1= ['A','B']
185        pars2= ['value']
186       
187        #Importing the Fit module
188        fitter = Fit('park')
189        fitter.set_data(data1,1)
190        fitter.set_model(model1,1,pars1)
191        fitter.set_data(data2,2,smearer=None)
192        fitter.set_model(model2,2,pars2)
193        fitter.select_problem_for_fit(id=1,value=1)
194        fitter.select_problem_for_fit(id=2,value=1)
195       
196        R1,R2 = fitter.fit(handler=FitHandler())
197        self.assertTrue( math.fabs(R1.pvec[0]-4.0)/3. <= R1.stderr[0])
198        self.assertTrue( math.fabs(R1.pvec[1]-2.5)/3. <= R1.stderr[1])
199        self.assertTrue( R1.fitness/(len(data1.x)+len(data2.x)) < 2)
200       
201       
202    def test4(self):
203        """ fit 2 data concatenates with limited range of x and  one model """
204            #load data
205        l = Loader()
206        data1 = l.load("testdata_line.txt")
207        data1.name = data1.filename
208        data2 = l.load("testdata_line1.txt")
209        data2.name = data2.filename
210
211        # Receives the type of model for the fitting
212        model1  = LineModel()
213        model1.name= "M1"
214        model1.setParam("A", 1.0)
215        model1.setParam("B",1.0)
216        model = Model(model1,data1)
217     
218        #fit with scipy test
219        pars1= ['A','B']
220        #Importing the Fit module
221        fitter = Fit('scipy')
222        fitter.set_data(data1,1,qmin=0, qmax=7)
223        fitter.set_model(model,1,pars1)
224        fitter.set_data(data2,1,qmin=1,qmax=10)
225        fitter.select_problem_for_fit(id=1,value=1)
226       
227        result1, = fitter.fit(handler=FitHandler())
228        #print(result1)
229        self.assert_(result1)
230
231        self.assertTrue( math.fabs(result1.pvec[0]-4)/3 <= result1.stderr[0] )
232        self.assertTrue( math.fabs(result1.pvec[1]-2.5)/3 <= result1.stderr[1])
233        self.assertTrue( result1.fitness/len(data1.x) < 2 )
234
235        #fit with park test
236        fitter = Fit('park')
237        fitter.set_data(data1,1,qmin=0, qmax=7)
238        fitter.set_model(model,1,pars1)
239        fitter.set_data(data2,1,qmin=1,qmax=10)
240        fitter.select_problem_for_fit(id=1,value=1)
241        result2, = fitter.fit(handler=FitHandler())
242       
243        self.assert_(result2)
244        self.assertTrue( math.fabs(result2.pvec[0]-4)/3 <= result2.stderr[0] )
245        self.assertTrue( math.fabs(result2.pvec[1]-2.5)/3 <= result2.stderr[1] )
246        self.assertTrue( result2.fitness/len(data1.x) < 2)
247        # compare fit result result for scipy and park
248        self.assertAlmostEquals( result1.pvec[0], result2.pvec[0] )
249        self.assertAlmostEquals( result1.pvec[1],result2.pvec[1] )
250        self.assertAlmostEquals( result1.stderr[0],result2.stderr[0] )
251        self.assertAlmostEquals( result1.stderr[1],result2.stderr[1] )
252        self.assertTrue( result2.fitness/(len(data2.x)+len(data1.x)) < 2 )
253
254
if __name__ == "__main__":
    # Run every TestCase defined in this script when executed directly.
    unittest.main(module="__main__")
Note: See TracBrowser for help on using the repository browser.