source: sasview/src/sas/qtgui/Plotting/UnitTesting/LinearFitTest.py @ aea6bb7

ESS_GUIESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_openclESS_GUI_sync_sascalc
Last change on this file since aea6bb7 was 53c771e, checked in by Piotr Rozyczko <rozyczko@…>, 7 years ago

Converted unit tests

  • Property mode set to 100644
File size: 5.9 KB
Line 
1import sys
2import unittest
3import numpy
4
5from PyQt5 import QtGui, QtWidgets
6from unittest.mock import MagicMock
7
8from UnitTesting.TestUtils import QtSignalSpy
9
10# set up import paths
11import path_prepare
12
13from sas.qtgui.Plotting.PlotterData import Data1D
14import sas.qtgui.Plotting.Plotter as Plotter
15
16# Local
17from sas.qtgui.Plotting.LinearFit import LinearFit
18
19if not QtWidgets.QApplication.instance():
20    app = QtWidgets.QApplication(sys.argv)
21
22class LinearFitTest(unittest.TestCase):
23    '''Test the LinearFit'''
24    def setUp(self):
25        '''Create the LinearFit'''
26        self.data = Data1D(x=[1.0, 2.0, 3.0],
27                           y=[10.0, 11.0, 12.0],
28                           dx=[0.1, 0.2, 0.3],
29                           dy=[0.1, 0.2, 0.3])
30        plotter = Plotter.Plotter(None, quickplot=True)
31        self.widget = LinearFit(parent=plotter, data=self.data, xlabel="log10(x^2)", ylabel="log10(y)")
32
33    def tearDown(self):
34        '''Destroy the GUI'''
35        self.widget.close()
36        self.widget = None
37
38    def testDefaults(self):
39        '''Test the GUI in its default state'''
40        self.assertIsInstance(self.widget, QtWidgets.QDialog)
41        self.assertEqual(self.widget.windowTitle(), "Linear Fit")
42        self.assertEqual(self.widget.txtA.text(), "1")
43        self.assertEqual(self.widget.txtB.text(), "1")
44        self.assertEqual(self.widget.txtAerr.text(), "0")
45        self.assertEqual(self.widget.txtBerr.text(), "0")
46
47        self.assertEqual(self.widget.lblRange.text(), "Fit range of log10(x^2)")
48
49    def testFit(self):
50        '''Test the fitting wrapper '''
51        # Catch the update signal
52        #self.widget.updatePlot.emit = MagicMock()
53        #self.widget.updatePlot.emit = MagicMock()
54        spy_update = QtSignalSpy(self.widget, self.widget.updatePlot)
55
56        # Set some initial values
57        self.widget.txtRangeMin.setText("1.0")
58        self.widget.txtRangeMax.setText("3.0")
59        self.widget.txtFitRangeMin.setText("1.0")
60        self.widget.txtFitRangeMax.setText("3.0")
61        # Run the fitting
62        self.widget.fit(None)
63
64        # Expected one spy instance
65        self.assertEqual(spy_update.count(), 1)
66
67        return_values = spy_update.called()[0]['args'][0]
68        # Compare
69        self.assertCountEqual(return_values[0], [1.0, 3.0])
70        self.assertAlmostEqual(return_values[1][0], 10.004054329, 6)
71        self.assertAlmostEqual(return_values[1][1], 12.030439848, 6)
72
73        # Set the log scale
74        self.widget.x_is_log = True
75        self.widget.fit(None)
76        self.assertEqual(spy_update.count(), 2)
77        return_values = spy_update.called()[1]['args'][0]
78        # Compare
79        self.assertCountEqual(return_values[0], [1.0, 3.0])
80        self.assertAlmostEqual(return_values[1][0], 9.987732937, 6)
81        self.assertAlmostEqual(return_values[1][1], 11.84365082, 6)
82
83    def testOrigData(self):
84        ''' Assure the un-logged data is returned'''
85        # log(x), log(y)
86        self.widget.xminFit, self.widget.xmaxFit = self.widget.range()
87        orig_x = [ 1.,  2.,  3.]
88        orig_y = [1.0, 1.0413926851582251, 1.0791812460476249]
89        orig_dy = [0.01, 0.018181818181818184, 0.024999999999999998]
90        x, y, dy = self.widget.origData()
91
92        self.assertCountEqual(x, orig_x)
93        self.assertEqual(y[0], orig_y[0])
94        self.assertAlmostEqual(y[1], orig_y[1], 8)
95        self.assertAlmostEqual(y[2], orig_y[2], 8)
96        self.assertEqual(dy[0], orig_dy[0])
97        self.assertAlmostEqual(dy[1], orig_dy[1], 8)
98        self.assertAlmostEqual(dy[2], orig_dy[2], 8)
99
100        # x, y
101        self.widget.x_is_log = False
102        self.widget.y_is_log = False
103        self.widget.xminFit, self.widget.xmaxFit = self.widget.range()
104        orig_x = [ 1.,  2.,  3.]
105        orig_y = [10., 11., 12.]
106        orig_dy = [0.1, 0.2, 0.3]
107        x, y, dy = self.widget.origData()
108
109        self.assertCountEqual(x, orig_x)
110        self.assertCountEqual(y, orig_y)
111        self.assertCountEqual(dy, orig_dy)
112
113        # x, log(y)
114        self.widget.x_is_log = False
115        self.widget.y_is_log = True
116        self.widget.xminFit, self.widget.xmaxFit = self.widget.range()
117        orig_x = [ 1.,  2.,  3.]
118        orig_y = [1.0, 1.0413926851582251, 1.0791812460476249]
119        orig_dy = [0.01, 0.018181818181818184, 0.024999999999999998]
120        x, y, dy = self.widget.origData()
121
122        self.assertCountEqual(x, orig_x)
123        self.assertEqual(y[0], orig_y[0])
124        self.assertAlmostEqual(y[1], orig_y[1], 8)
125        self.assertAlmostEqual(y[2], orig_y[2], 8)
126        self.assertEqual(dy[0], orig_dy[0])
127        self.assertAlmostEqual(dy[1], orig_dy[1], 8)
128        self.assertAlmostEqual(dy[2], orig_dy[2], 8)
129
130    def testCheckFitValues(self):
131        '''Assure fit values are correct'''
132        # Good values
133        self.assertTrue(self.widget.checkFitValues(self.widget.txtFitRangeMin))
134        # Colors platform dependent
135        #self.assertEqual(self.widget.txtFitRangeMin.palette().color(10).name(), "#f0f0f0")
136        # Bad values
137        self.widget.x_is_log = True
138        self.widget.txtFitRangeMin.setText("-1.0")
139        self.assertFalse(self.widget.checkFitValues(self.widget.txtFitRangeMin))
140       
141
142    def testFloatInvTransform(self):
143        '''Test the helper method for providing conversion function'''
144        self.widget.xLabel="x"
145        self.assertEqual(self.widget.floatInvTransform(5.0), 5.0)
146        self.widget.xLabel="x^(2)"
147        self.assertEqual(self.widget.floatInvTransform(25.0), 5.0)
148        self.widget.xLabel="x^(4)"
149        self.assertEqual(self.widget.floatInvTransform(81.0), 3.0)
150        self.widget.xLabel="log10(x)"
151        self.assertEqual(self.widget.floatInvTransform(2.0), 100.0)
152        self.widget.xLabel="ln(x)"
153        self.assertEqual(self.widget.floatInvTransform(1.0), numpy.exp(1))
154        self.widget.xLabel="log10(x^(4))"
155        self.assertEqual(self.widget.floatInvTransform(4.0), 10.0)
156     
157if __name__ == "__main__":
158    unittest.main()
Note: See TracBrowser for help on using the repository browser.