source: sasview/test/sasfit/test/utest_fit_line.py @ b699768

ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Last change on this file since b699768 was b699768, checked in by Piotr Rozyczko <piotr.rozyczko@…>, 8 years ago

Initial commit of the refactored SasCalc module.

  • Property mode set to 100644
File size: 7.4 KB
Line 
1"""
2    Unit tests for fitting module
3    @author Gervaise Alina
4"""
5import unittest
6import math
7
8from sas.sascalc.fit.AbstractFitEngine import Model, FitHandler
9from sas.sascalc.dataloader.loader import Loader
10from sas.sascalc.fit.BumpsFitting import BumpsFit as Fit
11from sas.models.LineModel import LineModel
12from sas.models.Constant import Constant
13
14from bumps import fitters
try:
    # Newer bumps releases keep fitter selection and options in FIT_CONFIG.
    from bumps.options import FIT_CONFIG

    def set_fitter(alg, opts):
        """
        Select bumps fitter *alg* and merge *opts* into its options.

        :param alg: bumps fitter id, e.g. 'lm', 'de', 'dream', 'amoeba',
            'newton' (the ids used by the tests in this module)
        :param opts: dict of fitter options, e.g. {'burn': 500, 'steps': 100}
        """
        FIT_CONFIG.selected_id = alg
        # monitors=[] is merged in together with opts; presumably this
        # disables fit progress monitors -- confirm against bumps.
        FIT_CONFIG.values[alg].update(opts, monitors=[])
except ImportError:
    # CRUFT: Bumps changed its handling of fit options around 0.7.5.6;
    # older releases have no bumps.options.FIT_CONFIG and instead expose
    # the default fitter and per-fitter options on bumps.fitters.
    def set_fitter(alg, opts):
        """
        Select bumps fitter *alg* and merge *opts* into its options
        (pre-FIT_CONFIG bumps API).

        :param alg: bumps fitter id, e.g. 'lm', 'de', 'dream'
        :param opts: dict of fitter options
        """
        fitters.FIT_DEFAULT = alg
        fitters.FIT_OPTIONS[alg].options.update(opts, monitors=[])
27
28
class testFitModule(unittest.TestCase):
    """Fit LineModel / Constant models to line test data with bumps."""

    def test_bad_pars(self):
        """Asking to fit parameters the model lacks raises ValueError."""
        fitter = Fit()

        data = Loader().load("testdata_line.txt")
        data.name = data.filename
        fitter.set_data(data, 1)

        model1 = LineModel()
        model1.name = "M1"
        model = Model(model1, data)
        # Not LineModel parameters (the valid ones, used below, are A and B).
        pars1 = ['param1', 'param2']
        with self.assertRaises(ValueError) as ctx:
            fitter.set_model(model, 1, pars1)
        self.assertTrue(str(ctx.exception).startswith('parameter param1'),
                        "unexpected error message: %s" % ctx.exception)

    def fit_single(self, isdream=False):
        """
        Fit LineModel parameters A and B to testdata_line.txt and compare
        the result against recorded reference values.

        :param isdream: True when the selected bumps fitter is DREAM. DREAM
            samples rather than minimizes, so the parameters only have to
            land within one reference standard error, and the fitness value
            is not checked.
        """
        fitter = Fit()

        data = Loader().load("testdata_line.txt")
        data.name = data.filename
        fitter.set_data(data, 1)

        model1 = LineModel()
        model1.name = "M1"
        model = Model(model1, data)

        pars1 = ['A', 'B']
        fitter.set_model(model, 1, pars1)
        fitter.select_problem_for_fit(id=1, value=1)
        result1, = fitter.fit(handler=FitHandler())

        p, fx = result1.pvec, result1.fitness
        # Reference values were recorded from an earlier run by printing
        # result1.pvec (p0, p1), result1.stderr (s0, s1) and
        # result1.fitness (fx_).
        p0, p1, s0, s1, fx_ = 3.68353, 2.61004, 0.336186, 0.105244, 1.20189

        if isdream:
            self.assertAlmostEqual(p[0], p0, delta=s0)
            self.assertAlmostEqual(p[1], p1, delta=s1)
        else:
            self.assertAlmostEqual(p[0], p0, delta=1e-5)
            self.assertAlmostEqual(p[1], p1, delta=1e-5)
            self.assertAlmostEqual(fx, fx_, delta=1e-5)

    def fit_bumps(self, alg, **opts):
        """
        Select bumps fitter *alg* with options *opts*, then run fit_single.

        NOTE(review): set_fitter mutates module-level bumps configuration,
        so the chosen fitter stays selected for any test that runs later in
        the same process.
        """
        set_fitter(alg, opts)
        self.fit_single(isdream=(alg == 'dream'))

    def test_bumps_de(self):
        self.fit_bumps('de')

    def test_bumps_dream(self):
        self.fit_bumps('dream', burn=500, steps=100)

    def test_bumps_amoeba(self):
        self.fit_bumps('amoeba')

    def test_bumps_newton(self):
        self.fit_bumps('newton')

    def test_bumps_lm(self):
        self.fit_bumps('lm')

    def test2(self):
        """Fit two data sets with two independent models (no constraints)."""
        loader = Loader()
        data1 = loader.load("testdata_line.txt")
        data1.name = data1.filename
        data2 = loader.load("testdata_line1.txt")
        data2.name = data2.filename

        fitter = Fit()
        # Each model gets its own name.
        model11 = LineModel()
        model11.name = "M1"
        model22 = LineModel()
        model22.name = "M2"

        model1 = Model(model11, data1)
        model2 = Model(model22, data2)
        pars1 = ['A', 'B']
        fitter.set_data(data1, 1)
        fitter.set_model(model1, 1, pars1)
        fitter.select_problem_for_fit(id=1, value=0)
        fitter.set_data(data2, 2)
        fitter.set_model(model2, 2, pars1)
        fitter.select_problem_for_fit(id=2, value=0)

        # Both problems are deselected, so the fit must refuse to run.
        with self.assertRaises(RuntimeError) as ctx:
            fitter.fit(handler=FitHandler())
        self.assertEqual(str(ctx.exception), "Nothing to fit")

        fitter.select_problem_for_fit(id=1, value=1)
        fitter.select_problem_for_fit(id=2, value=1)
        R1, R2 = fitter.fit(handler=FitHandler())

        # A ~ 4 and B ~ 2.5 must hold to within three standard errors.
        self.assertTrue(math.fabs(R1.pvec[0] - 4) / 3 <= R1.stderr[0])
        self.assertTrue(math.fabs(R1.pvec[1] - 2.5) / 3 <= R1.stderr[1])
        self.assertTrue(R1.fitness / (len(data1.x) + len(data2.x)) < 2)

    def test_constraints(self):
        """Fit two data sets with two models linked by one constraint."""
        loader = Loader()
        data1 = loader.load("testdata_line.txt")
        data1.name = data1.filename
        data2 = loader.load("testdata_cst.txt")
        data2.name = data2.filename

        model11 = LineModel()
        model11.name = "line"
        model11.setParam("A", 1.0)
        model11.setParam("B", 1.0)

        model22 = Constant()
        model22.name = "cst"
        model22.setParam("value", 1.0)

        model1 = Model(model11, data1)
        model2 = Model(model22, data2)
        model1.set(A=4)
        model1.set(B=3)
        pars1 = ['A', 'B']
        pars2 = ['value']

        fitter = Fit()
        fitter.set_data(data1, 1)
        fitter.set_model(model1, 1, pars1)
        fitter.set_data(data2, 2, smearer=None)
        # Constrain the constant's value to equal parameter B of "line"
        # (the real value is 2.5).
        fitter.set_model(model2, 2, pars2, constraints=[("value", "line.B")])
        fitter.select_problem_for_fit(id=1, value=1)
        fitter.select_problem_for_fit(id=2, value=1)

        R1, R2 = fitter.fit(handler=FitHandler())
        # A ~ 4 and B ~ 2.5 must hold to within three standard errors.
        self.assertTrue(math.fabs(R1.pvec[0] - 4.0) / 3. <= R1.stderr[0])
        self.assertTrue(math.fabs(R1.pvec[1] - 2.5) / 3. <= R1.stderr[1])
        self.assertTrue(R1.fitness / (len(data1.x) + len(data2.x)) < 2)

    def test4(self):
        """
        Fit one model to two data sets registered under the same problem
        id (1), each restricted to its own q range.
        """
        loader = Loader()
        data1 = loader.load("testdata_line.txt")
        data1.name = data1.filename
        data2 = loader.load("testdata_line1.txt")
        data2.name = data2.filename

        model1 = LineModel()
        model1.name = "M1"
        model1.setParam("A", 1.0)
        model1.setParam("B", 1.0)
        model = Model(model1, data1)

        pars1 = ['A', 'B']

        fitter = Fit()
        fitter.set_data(data1, 1, qmin=0, qmax=7)
        fitter.set_model(model, 1, pars1)
        fitter.set_data(data2, 1, qmin=1, qmax=10)
        fitter.select_problem_for_fit(id=1, value=1)
        result2, = fitter.fit(handler=FitHandler())

        self.assertTrue(result2)
        # A ~ 4 and B ~ 2.5 must hold to within three standard errors.
        self.assertTrue(math.fabs(result2.pvec[0] - 4) / 3 <= result2.stderr[0])
        self.assertTrue(math.fabs(result2.pvec[1] - 2.5) / 3 <= result2.stderr[1])
        self.assertTrue(result2.fitness / len(data1.x) < 2)
216
217
if __name__ == "__main__":
    # Load and run every TestCase defined in this module when executed
    # directly as a script.
    unittest.main(module="__main__")
220   
Note: See TracBrowser for help on using the repository browser.