source: sasview/park_integration/test/batch_fit.py @ f368fd9

ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Last change on this file since f368fd9 was d1c8036, checked in by Gervaise Alina <gervyh@…>, 13 years ago

update unittest

  • Property mode set to 100644
File size: 6.0 KB
RevLine 
[e5df560]1
[d1c8036]2
[e5df560]3import math
4import numpy
5import copy
6import time
7import unittest
[d48858da]8from sans.dataloader.loader import Loader
[e5df560]9from sans.fit.Fitting import Fit
10from sans.models.CylinderModel import CylinderModel
11import sans.models.dispersion_models 
12from sans.models.qsmearing import smear_selection
13
# Number of fitters that BatchScipyFit._reset_helper creates for each data file.
NPTS = 1
15
[d48858da]16
17
18   
def classMapper(classInstance, classFunc, *args):
    """
    Call the method named ``classFunc`` on ``classInstance``.

    :param classInstance: object whose method is invoked
    :param classFunc: name of the method, as a string
    :param args: positional arguments forwarded to the method
    :return: whatever the method call returns

    It is defined at module level so it can be handed to ``map`` (and, via
    ``mapapply``, to multiprocessing workers) as a plain callable.
    """
    bound_method = getattr(classInstance, classFunc)
    return bound_method(*args)
25
def mapapply(arguments):
    """
    Call ``arguments[0]`` with the remaining items as positional arguments.

    ``Pool.map`` hands exactly one item to its worker function, so each task
    is packed as ``(callable, arg1, arg2, ...)`` and unpacked here.

    :param arguments: sequence whose first item is a callable
    :return: the result of ``arguments[0](*arguments[1:])``
    """
    # Equivalent to the deprecated Python 2 ``apply(f, args)`` builtin,
    # which no longer exists in Python 3.
    func = arguments[0]
    return func(*arguments[1:])
28
29
[d48858da]30
class BatchScipyFit:
    """
    Build a batch of independent fitters and run them with ``map``.

    For every file in ``DATA_FILES`` the data is loaded and a polydisperse
    ``CylinderModel`` plus a smearer are attached. The configured fitter is
    stored in ``list_of_fitter``, the method name to call (``'fit'``) in
    ``list_of_function``, and the mapper that calls it in
    ``list_of_mapper``.
    """

    # Data files fitted by reset_value, in order.
    DATA_FILES = ("testdata_line3.txt", "testdata_line.txt",
                  "SILIC010_noheader.DAT", "cyl_400_20.txt",
                  "sphere_80.txt", "PolySpheres.txt",
                  "latex_qdev.txt", "latex_qdev2.txt")

    def __init__(self, qmin=None, qmax=None):
        """
        :param qmin: lower Q bound passed to every fitter
        :param qmax: upper Q bound passed to every fitter

        The tests pass ``None`` for both when they want the full range of
        each data set.
        """
        self.list_of_fitter = []
        self.list_of_function = []
        self.param_to_fit = ['scale', 'length', 'radius']
        self.list_of_constraints = []
        self.list_of_mapper = []
        self.polydisp = sans.models.dispersion_models.models
        self.qmin = qmin
        # Previously this was assigned ``qmin``, which silently dropped qmax.
        self.qmax = qmax
        self.reset_value()

    def set_range(self, qmin=None, qmax=None):
        """
        Set the Q range used by fitters built by the next ``reset_value``.

        Fitters that already exist are not modified.
        """
        self.qmin = qmin
        self.qmax = qmax

    def _reset_helper(self, path=None, engine="scipy", npts=NPTS):
        """
        Create ``npts`` fitters for the data file at ``path``.

        Each fitter is appended to the batch lists together with the
        function name to call and the mapper used to call it.

        :param path: data file loaded with ``Loader``
        :param engine: fitting engine name passed to ``Fit``
        :param npts: number of fitters to create for this file
        """
        try:
            import cPickle as pickle
        except ImportError:  # Python 3
            import pickle
        for i in range(npts):
            data = Loader().load(path)
            fitter = Fit(engine)
            # Cylinder model with Gaussian dispersion on every dispersion
            # parameter.
            model = CylinderModel()
            model.setParam('scale', 1.0)
            model.setParam('radius', 20.0)
            model.setParam('length', 400.0)
            model.setParam('sldCyl', 4e-006)
            model.setParam('sldSolv', 1e-006)
            model.setParam('background', 0.0)
            for param in model.dispersion.keys():
                model.set_dispersion(param, self.polydisp['gaussian']())
            model.setParam('cyl_phi.width', 10)
            model.setParam('cyl_phi.npts', 3)
            model.setParam('cyl_theta.nsigmas', 10)
            # Original note: "for 2 data cyl_theta = 60.0 [deg]
            # cyl_phi = 60.0 [deg]" (presumably meaning 2D data).
            fitter.set_model(model, i, self.param_to_fit,
                             self.list_of_constraints)
            current_smearer = smear_selection(data, model)
            # The batch is later pickled and sent to worker processes by
            # multiprocessing. Fail here, rather than inside a worker, if
            # the smearer cannot survive a pickle round trip.
            pickle.loads(pickle.dumps(current_smearer))
            fitter.set_data(data=data, id=i, smearer=current_smearer,
                            qmin=self.qmin, qmax=self.qmax)
            fitter.select_problem_for_fit(id=i, value=1)
            self.list_of_fitter.append(copy.deepcopy(fitter))
            self.list_of_function.append('fit')
            self.list_of_mapper.append(classMapper)

    def reset_value(self):
        """
        Clear the batch and rebuild the fitters for every file in
        ``DATA_FILES``, using the current Q range.
        """
        self.list_of_fitter = []
        self.list_of_function = []
        self.param_to_fit = ['scale', 'length', 'radius']
        self.list_of_constraints = []
        self.list_of_mapper = []
        engine = "scipy"
        for path in self.DATA_FILES:
            self._reset_helper(path=path, engine=engine, npts=NPTS)

    def test_map_fit(self):
        """
        Run every fitter sequentially with the built-in ``map`` and print
        the fitness, stderr and pvec of each result.
        """
        results = list(map(classMapper, self.list_of_fitter,
                           self.list_of_function))
        print(len(results))
        for result in results:
            print("%s %s %s" % (result.fitness, result.stderr, result.pvec))

    def test_process_map_fit(self, n=1):
        """
        Run the batch with ``multiprocessing.Pool.map``.

        Prints timing information and each fit result.

        :param n: number of worker processes
        """
        t0 = time.time()
        print("start fit with %s process(es) at %s"
              % (str(n), time.strftime(" %H:%M:%S", time.localtime(t0))))
        from multiprocessing import Pool
        tasks = list(zip(self.list_of_mapper, self.list_of_fitter,
                         self.list_of_function))
        pool = Pool(n)
        try:
            results = pool.map(mapapply, tasks)
        except Exception:
            # Do not leave workers running after a failed fit.
            pool.terminate()
            raise
        else:
            pool.close()
        finally:
            # Reap the worker processes (previously the pool was leaked).
            pool.join()
        t1 = time.time()
        print("got fit results  %s %s"
              % (time.strftime(" %H:%M:%S", time.localtime(t1)), t1 - t0))
        print(len(results))
        for result in results:
            print("%s %s %s" % (result.fitness, result.stderr, result.pvec))
        t2 = time.time()
        print("print fit1 results  %s %s"
              % (time.strftime(" %H:%M:%S", time.localtime(t2)), t2 - t1))
[e5df560]140               
class testBatch(unittest.TestCase):
    """
    Batch-fitting tests driven by ``BatchScipyFit``.

    The two built-in ``map`` variants are named with a leading double
    underscore. Their (name-mangled) names do not start with ``test``, so
    unittest does not collect them and they are currently disabled.
    """

    def setUp(self):
        """Build a fresh batch over the full Q range of each data file."""
        self.test = BatchScipyFit(qmin=None, qmax=None)

    def __test_fit1(self):
        """test fit with python built in map function---- full range of each data"""
        self.test.test_map_fit()

    def __test_fit2(self):
        """test fit with python built in map function---- common range for all data"""
        self.test.set_range(qmin=0.013, qmax=0.05)
        self.test.reset_value()
        self.test.test_map_fit()

    def test_fit3(self):
        """test fit with data full range using 2 processes and Pool.map"""
        self.test.set_range(qmin=None, qmax=None)
        self.test.reset_value()
        self.test.test_process_map_fit(n=2)

    def test_fit4(self):
        """test fit with a common fixed range (qmin=-1, qmax=10) using 1 process and Pool.map"""
        self.test.set_range(qmin=-1, qmax=10)
        self.test.reset_value()
        self.test.test_process_map_fit(n=1)
[e5df560]170       
171           
# Run the test cases when this module is executed directly.
if "__main__" == __name__:
    unittest.main()
174   
175   
176   
Note: See TracBrowser for help on using the repository browser.