Changeset b46f285 in sasview


Ignore:
Timestamp:
Jan 9, 2017 7:49:16 AM (7 years ago)
Author:
Piotr Rozyczko <rozyczko@…>
Branches:
ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc
Children:
87cc73a
Parents:
a66ff280
Message:

Unit tests for linear fit

Location:
src/sas/qtgui
Files:
1 added
9 edited

Legend:

Unmodified
Added
Removed
  • src/sas/qtgui/GUITests.py

    rd3ca363 rb46f285  
    2020from UnitTesting import WindowTitleTest 
    2121from UnitTesting import SetGraphRangeTest 
     22from UnitTesting import LinearFitTest 
    2223 
    2324def suite(): 
     
    4243        unittest.makeSuite(WindowTitleTest.WindowTitleTest, 'test'), 
    4344        unittest.makeSuite(SetGraphRangeTest.SetGraphRangeTest, 'test'), 
     45        unittest.makeSuite(LinearFitTest.LinearFitTest, 'test'), 
    4446    ) 
    4547    return unittest.TestSuite(suites) 
  • src/sas/qtgui/LinearFit.py

    rfed94a2 rb46f285  
    3838        self.yLabel = ylabel 
    3939 
     40        self.x_is_log = self.xLabel == "log10(x)" 
     41        self.y_is_log = self.yLabel == "log10(y)" 
     42 
    4043        self.txtFitRangeMin.setValidator(QtGui.QDoubleValidator()) 
    4144        self.txtFitRangeMax.setValidator(QtGui.QDoubleValidator()) 
     
    108111        self.xminFit, self.xmaxFit = self.range() 
    109112 
    110         xminView = self.xminFit 
    111         xmaxView = self.xmaxFit 
    112         xmin = xminView 
    113         xmax = xmaxView 
     113        xmin = self.xminFit 
     114        xmax = self.xmaxFit 
     115        xminView = xmin 
     116        xmaxView = xmax 
     117 
    114118        # Set the qmin and qmax in the panel that matches the 
    115119        # transformed min and max 
    116120        #value_xmin = X_VAL_DICT[self.xLabel].floatTransform(xmin) 
    117121        #value_xmax = X_VAL_DICT[self.xLabel].floatTransform(xmax) 
     122 
    118123        value_xmin = self.floatInvTransform(xmin) 
    119124        value_xmax = self.floatInvTransform(xmax) 
     
    121126        self.txtRangeMax.setText(formatNumber(value_xmax)) 
    122127 
     128        tempx, tempy, tempdy = self.origData() 
     129 
     130        # Find the fitting parameters 
     131        self.cstA = fittings.Parameter(self.model, 'A', self.default_A) 
     132        self.cstB = fittings.Parameter(self.model, 'B', self.default_B) 
     133        tempdy = numpy.asarray(tempdy) 
     134        tempdy[tempdy == 0] = 1 
     135 
     136        if self.x_is_log: 
     137            xmin = numpy.log10(xmin) 
     138            xmax = numpy.log10(xmax) 
     139 
     140        chisqr, out, cov = fittings.sasfit(self.model, 
     141                                           [self.cstA, self.cstB], 
     142                                           tempx, tempy, tempdy, 
     143                                           xmin, xmax) 
     144        # Use chi2/dof 
     145        if len(tempx) > 0: 
     146            chisqr = chisqr / len(tempx) 
     147 
     148        # Check that cov and out are iterable before displaying them 
     149        errA = numpy.sqrt(cov[0][0]) if cov is not None else 0 
     150        errB = numpy.sqrt(cov[1][1]) if cov is not None else 0 
     151        cstA = out[0] if out is not None else 0.0 
     152        cstB = out[1] if out is not None else 0.0 
     153 
     154        # Reset model with the right values of A and B 
     155        self.model.setParam('A', float(cstA)) 
     156        self.model.setParam('B', float(cstB)) 
     157 
     158        tempx = [] 
     159        tempy = [] 
     160        y_model = 0.0 
     161 
     162        # load tempy with the minimum transformation 
     163        y_model = self.model.run(xmin) 
     164        tempx.append(xminView) 
     165        tempy.append(numpy.power(10, y_model) if self.y_is_log else y_model) 
     166 
     167        # load tempy with the maximum transformation 
     168        y_model = self.model.run(xmax) 
     169        tempx.append(xmaxView) 
     170        tempy.append(numpy.power(10, y_model) if self.y_is_log else y_model) 
     171 
     172        # Set the fit parameter display when  FitDialog is opened again 
     173        self.Avalue = cstA 
     174        self.Bvalue = cstB 
     175        self.ErrAvalue = errA 
     176        self.ErrBvalue = errB 
     177        self.Chivalue = chisqr 
     178 
     179        # Update the widget 
     180        self.txtA.setText(formatNumber(self.Avalue)) 
     181        self.txtAerr.setText(formatNumber(self.ErrAvalue)) 
     182        self.txtB.setText(formatNumber(self.Bvalue)) 
     183        self.txtBerr.setText(formatNumber(self.ErrBvalue)) 
     184        self.txtChi2.setText(formatNumber(self.Chivalue)) 
     185 
     186        #self.parent.updatePlot.emit((tempx, tempy)) 
     187        self.parent.emit(QtCore.SIGNAL('updatePlot'), (tempx, tempy)) 
     188 
     189    def origData(self): 
    123190        # Store the transformed values of view x, y and dy before the fit 
    124         xmin_check = numpy.log10(xmin) 
     191        xmin_check = numpy.log10(self.xminFit) 
     192        # Local shortcuts 
    125193        x = self.data.view.x 
    126194        y = self.data.view.y 
    127195        dy = self.data.view.dy 
    128196 
    129         if self.yLabel == "log10(y)": 
    130             if self.xLabel == "log10(x)": 
     197        if self.y_is_log: 
     198            if self.x_is_log: 
    131199                tempy  = [numpy.log10(y[i]) 
    132200                         for i in range(len(x)) if x[i] >= xmin_check] 
     
    140208            tempdy = dy 
    141209 
    142         if self.xLabel == "log10(x)": 
     210        if self.x_is_log: 
    143211            tempx = [numpy.log10(x) for x in self.data.view.x if x > xmin_check] 
    144212        else: 
    145             tempx = self.data.view.x 
    146  
    147         # Find the fitting parameters 
    148         # Always use the same defaults, so that fit history 
    149         # doesn't play a role! 
    150         self.cstA = fittings.Parameter(self.model, 'A', self.default_A) 
    151         self.cstB = fittings.Parameter(self.model, 'B', self.default_B) 
    152         tempdy = numpy.asarray(tempdy) 
    153         tempdy[tempdy == 0] = 1 
    154  
    155         if self.xLabel == "log10(x)": 
    156             chisqr, out, cov = fittings.sasfit(self.model, 
    157                                                [self.cstA, self.cstB], 
    158                                                tempx, tempy, 
    159                                                tempdy, 
    160                                                numpy.log10(xmin), 
    161                                                numpy.log10(xmax)) 
    162         else: 
    163             chisqr, out, cov = fittings.sasfit(self.model, 
    164                                                [self.cstA, self.cstB], 
    165                                                tempx, tempy, tempdy, 
    166                                                xminView, xmaxView) 
    167         # Use chi2/dof 
    168         if len(tempx) > 0: 
    169             chisqr = chisqr / len(tempx) 
    170  
    171         # Check that cov and out are iterable before displaying them 
    172         errA = numpy.sqrt(cov[0][0]) if cov is not None else 0 
    173         errB = numpy.sqrt(cov[1][1]) if cov is not None else 0 
    174         cstA = out[0] if out is not None else 0.0 
    175         cstB = out[1] if out is not None else 0.0 
    176  
    177         # Reset model with the right values of A and B 
    178         self.model.setParam('A', float(cstA)) 
    179         self.model.setParam('B', float(cstB)) 
    180  
    181         tempx = [] 
    182         tempy = [] 
    183         y_model = 0.0 
    184         # load tempy with the minimum transformation 
    185         if self.xLabel == "log10(x)": 
    186             y_model = self.model.run(numpy.log10(xmin)) 
    187             tempx.append(xmin) 
    188         else: 
    189             y_model = self.model.run(xminView) 
    190             tempx.append(xminView) 
    191  
    192         if self.yLabel == "log10(y)": 
    193             tempy.append(numpy.power(10, y_model)) 
    194         else: 
    195             tempy.append(y_model) 
    196  
    197         # load tempy with the maximum transformation 
    198         if self.xLabel == "log10(x)": 
    199             y_model = self.model.run(numpy.log10(xmax)) 
    200             tempx.append(xmax) 
    201         else: 
    202             y_model = self.model.run(xmaxView) 
    203             tempx.append(xmaxView) 
    204  
    205         if self.yLabel == "log10(y)": 
    206             tempy.append(numpy.power(10, y_model)) 
    207         else: 
    208             tempy.append(y_model) 
    209         # Set the fit parameter display when  FitDialog is opened again 
    210         self.Avalue = cstA 
    211         self.Bvalue = cstB 
    212         self.ErrAvalue = errA 
    213         self.ErrBvalue = errB 
    214         self.Chivalue = chisqr 
    215  
    216         # Update the widget 
    217         self.txtA.setText(formatNumber(self.Avalue)) 
    218         self.txtAerr.setText(formatNumber(self.ErrAvalue)) 
    219         self.txtB.setText(formatNumber(self.Bvalue)) 
    220         self.txtBerr.setText(formatNumber(self.ErrBvalue)) 
    221         self.txtChi2.setText(formatNumber(self.Chivalue)) 
    222  
    223         #self.parent.updatePlot.emit((tempx, tempy)) 
    224         self.parent.emit(QtCore.SIGNAL('updatePlot'), (tempx, tempy)) 
     213            tempx = x 
     214 
     215        return tempx, tempy, tempdy 
    225216 
    226217    def checkFitValues(self, item): 
     
    234225        p_pink = item.palette() 
    235226        p_pink.setColor(item.backgroundRole(), QtGui.QColor(255, 128, 128)) 
     227        item.setAutoFillBackground(True) 
    236228        # Check for possible values entered 
    237         if self.xLabel == "log10(x)": 
     229        if self.x_is_log: 
    238230            if float(value) > 0: 
    239231                item.setPalette(p_white) 
     
    259251            return numpy.sqrt(x) 
    260252        elif self.xLabel == "x^(4)": 
    261             return numpy.sqrt(math.sqrt(x)) 
     253            return numpy.sqrt(numpy.sqrt(x)) 
    262254        elif self.xLabel == "log10(x)": 
    263255            return numpy.power(10, x) 
  • src/sas/qtgui/Plotter.py

    ra66ff280 rb46f285  
    382382        plot_dict = copy.deepcopy(self.plot_dict) 
    383383 
     384        # Labels might have been changed 
     385        xl = self.ax.xaxis.label.get_text() 
     386        yl = self.ax.yaxis.label.get_text() 
     387 
    384388        self.plot_dict = {} 
    385389 
     
    389393        for ids in plot_dict: 
    390394            if ids != id: 
    391                 self.plot(data=plot_dict[ids], hide_error=plot_dict[ids].hide_error)                 
     395                self.plot(data=plot_dict[ids], hide_error=plot_dict[ids].hide_error) 
     396 
     397        # Reset the labels 
     398        self.ax.set_xlabel(xl) 
     399        self.ax.set_ylabel(yl) 
     400        self.canvas.draw() 
    392401 
    393402    def onFreeze(self, id): 
     
    443452            self.xscale = xscale 
    444453            self.yscale = yscale 
     454 
     455            # Plot the updated chart 
     456            self.removePlot(id) 
     457 
     458            # This assignment will wrap the label in Latex "$" 
    445459            self.xLabel = new_xlabel 
    446460            self.yLabel = new_ylabel 
    447             # Plot the updated chart 
    448             self.removePlot(id) 
    449461            # Directly overwrite the data to avoid label reassignment 
    450462            self._data = current_plot 
  • src/sas/qtgui/Plotter2D.py

    r9290b1a rb46f285  
    8585                      zmax=zmax_2D_temp) 
    8686 
    87     def contextMenu(self): 
     87    def createContextMenu(self): 
    8888        """ 
    8989        Define common context menu and associated actions for the MPL widget 
     
    9191        self.defaultContextMenu() 
    9292 
    93     def contextMenuQuickPlot(self): 
     93    def createContextMenuQuick(self): 
    9494        """ 
    9595        Define context menu and associated actions for the quickplot MPL widget 
  • src/sas/qtgui/UnitTesting/GuiUtilsTest.py

    r27313b7 rb46f285  
    346346        self.assertFalse(os.path.isfile(file_name)) 
    347347 
     348    def testXYTransform(self): 
     349        """ Assure the unit/legend transformation is correct""" 
     350        data = Data1D(x=[1.0, 2.0, 3.0], y=[10.0, 11.0, 12.0], 
     351                      dx=[0.1, 0.2, 0.3], dy=[0.1, 0.2, 0.3]) 
     352 
     353        xLabel, yLabel, xscale, yscale = xyTransform(data, xLabel="x", yLabel="y") 
     354        self.assertEqual(xLabel, "()") 
     355        self.assertEqual(xscale, "linear") 
     356        self.assertEqual(yscale, "linear") 
     357 
     358        xLabel, yLabel, xscale, yscale = xyTransform(data, xLabel="x^(2)", yLabel="1/y") 
     359        self.assertEqual(xLabel, "^{2}(()^{2})") 
     360        self.assertEqual(yLabel, "1/(()^{-1})") 
     361        self.assertEqual(xscale, "linear") 
     362        self.assertEqual(yscale, "linear") 
     363 
     364        xLabel, yLabel, xscale, yscale = xyTransform(data, xLabel="x^(4)", yLabel="ln(y)") 
     365        self.assertEqual(xLabel, "^{4}(()^{4})") 
     366        self.assertEqual(yLabel, "\\ln{()}()") 
     367        self.assertEqual(xscale, "linear") 
     368        self.assertEqual(yscale, "linear") 
     369 
     370        xLabel, yLabel, xscale, yscale = xyTransform(data, xLabel="ln(x)", yLabel="y^(2)") 
     371        self.assertEqual(xLabel, "\\ln{()}()") 
     372        self.assertEqual(yLabel, "^{2}(()^{2})") 
     373        self.assertEqual(xscale, "linear") 
     374        self.assertEqual(yscale, "linear") 
     375 
     376        xLabel, yLabel, xscale, yscale = xyTransform(data, xLabel="log10(x)", yLabel="y*x^(2)") 
     377        self.assertEqual(xLabel, "()") 
     378        self.assertEqual(yLabel, " \\ \\ ^{2}(()^{2})") 
     379        self.assertEqual(xscale, "log") 
     380        self.assertEqual(yscale, "linear") 
     381 
     382        xLabel, yLabel, xscale, yscale = xyTransform(data, xLabel="log10(x^(4))", yLabel="y*x^(4)") 
     383        self.assertEqual(xLabel, "^{4}(()^{4})") 
     384        self.assertEqual(yLabel, " \\ \\ ^{4}(()^{16})") 
     385        self.assertEqual(xscale, "log") 
     386        self.assertEqual(yscale, "linear") 
     387 
     388        xLabel, yLabel, xscale, yscale = xyTransform(data, xLabel="x", yLabel="1/sqrt(y)") 
     389        self.assertEqual(yLabel, "1/\\sqrt{}(()^{-0.5})") 
     390        self.assertEqual(yscale, "linear") 
     391 
     392        xLabel, yLabel, xscale, yscale = xyTransform(data, xLabel="x", yLabel="log10(y)") 
     393        self.assertEqual(yLabel, "()") 
     394        self.assertEqual(yscale, "log") 
     395 
     396        xLabel, yLabel, xscale, yscale = xyTransform(data, xLabel="x", yLabel="ln(y*x)") 
     397        self.assertEqual(yLabel, "\\ln{( \\ \\ )}()") 
     398        self.assertEqual(yscale, "linear") 
     399 
     400        xLabel, yLabel, xscale, yscale = xyTransform(data, xLabel="x", yLabel="ln(y*x^(2))") 
     401        self.assertEqual(yLabel, "\\ln ( \\ \\ ^{2})(()^{2})") 
     402        self.assertEqual(yscale, "linear") 
     403 
     404        xLabel, yLabel, xscale, yscale = xyTransform(data, xLabel="x", yLabel="ln(y*x^(4))") 
     405        self.assertEqual(yLabel, "\\ln ( \\ \\ ^{4})(()^{4})") 
     406        self.assertEqual(yscale, "linear") 
     407 
     408        xLabel, yLabel, xscale, yscale = xyTransform(data, xLabel="x", yLabel="log10(y*x^(4))") 
     409        self.assertEqual(yLabel, " \\ \\ ^{4}(()^{4})") 
     410        self.assertEqual(yscale, "log") 
     411 
    348412class FormulaValidatorTest(unittest.TestCase): 
    349413    """ Test the formula validator """ 
  • src/sas/qtgui/UnitTesting/Plotter2DTest.py

    r27313b7 rb46f285  
    7474        """ Test the right click menu """ 
    7575        self.plotter.data = self.data 
     76        self.plotter.createContextMenuQuick() 
    7677        actions = self.plotter.contextMenu.actions() 
    7778        self.assertEqual(len(actions), 7) 
  • src/sas/qtgui/UnitTesting/PlotterTest.py

    raadf0af1 rb46f285  
    6969        self.assertTrue(FigureCanvas.draw.called) 
    7070 
    71     def testContextMenuQuickPlot(self): 
     71    def testCreateContextMenuQuick(self): 
    7272        """ Test the right click menu """ 
    7373        self.plotter.createContextMenuQuick() 
     
    112112        self.assertTrue(self.plotter.properties.exec_.called) 
    113113 
    114     def testXYTransform(self): 
    115         """ Assure the unit/legend transformation is correct""" 
    116         self.plotter.plot(self.data) 
    117  
    118         self.plotter.xyTransform(xLabel="x", yLabel="y") 
     114    def testXyTransform(self): 
     115        """ Tests the XY transformation and new chart update """ 
     116        self.plotter.plot(self.data) 
     117 
     118        # Transform the points 
     119        self.plotter.xyTransform(xLabel="x", yLabel="log10(y)") 
     120 
     121        # Assure new plot has correct labels 
    119122        self.assertEqual(self.plotter.ax.get_xlabel(), "$()$") 
    120123        self.assertEqual(self.plotter.ax.get_ylabel(), "$()$") 
    121  
    122         self.plotter.xyTransform(xLabel="x^(2)", yLabel="1/y") 
    123         self.assertEqual(self.plotter.ax.get_xlabel(), "$^{2}(()^{2})$") 
    124         self.assertEqual(self.plotter.ax.get_ylabel(), "$1/(()^{-1})$") 
    125  
    126         self.plotter.xyTransform(xLabel="x^(4)", yLabel="ln(y)") 
    127         self.assertEqual(self.plotter.ax.get_xlabel(), "$^{4}(()^{4})$") 
    128         self.assertEqual(self.plotter.ax.get_ylabel(), "$\\ln{()}()$") 
    129  
    130         self.plotter.xyTransform(xLabel="ln(x)", yLabel="y^(2)") 
    131         self.assertEqual(self.plotter.ax.get_xlabel(), "$\\ln{()}()$") 
    132         self.assertEqual(self.plotter.ax.get_ylabel(), "$^{2}(()^{2})$") 
    133  
    134         self.plotter.xyTransform(xLabel="log10(x)", yLabel="y*x^(2)") 
    135         self.assertEqual(self.plotter.ax.get_xlabel(), "$()$") 
    136         self.assertEqual(self.plotter.ax.get_ylabel(), "$ \\ \\ ^{2}(()^{2})$") 
    137  
    138         self.plotter.xyTransform(xLabel="log10(x^(4))", yLabel="y*x^(4)") 
    139         self.assertEqual(self.plotter.ax.get_xlabel(), "$^{4}(()^{4})$") 
    140         self.assertEqual(self.plotter.ax.get_ylabel(), "$ \\ \\ ^{4}(()^{16})$") 
    141  
    142         self.plotter.xyTransform(xLabel="x", yLabel="1/sqrt(y)") 
    143         self.assertEqual(self.plotter.ax.get_ylabel(), "$1/\\sqrt{}(()^{-0.5})$") 
    144  
    145         self.plotter.xyTransform(xLabel="x", yLabel="log10(y)") 
    146         self.assertEqual(self.plotter.ax.get_ylabel(), "$()$") 
    147  
    148         self.plotter.xyTransform(xLabel="x", yLabel="ln(y*x)") 
    149         self.assertEqual(self.plotter.ax.get_ylabel(), "$\\ln{( \\ \\ )}()$") 
    150  
    151         self.plotter.xyTransform(xLabel="x", yLabel="ln(y*x^(2))") 
    152         self.assertEqual(self.plotter.ax.get_ylabel(), "$\\ln ( \\ \\ ^{2})(()^{2})$") 
    153  
    154         self.plotter.xyTransform(xLabel="x", yLabel="ln(y*x^(4))") 
    155         self.assertEqual(self.plotter.ax.get_ylabel(), "$\\ln ( \\ \\ ^{4})(()^{4})$") 
    156  
    157         self.plotter.xyTransform(xLabel="x", yLabel="log10(y*x^(4))") 
    158         self.assertEqual(self.plotter.ax.get_ylabel(), "$ \\ \\ ^{4}(()^{4})$") 
     124        # ... and scale 
     125        self.assertEqual(self.plotter.xscale, "linear") 
     126        self.assertEqual(self.plotter.yscale, "log") 
     127        # See that just one plot is present 
     128        self.assertEqual(len(self.plotter.plot_dict), 1) 
     129        self.assertEqual(len(self.plotter.ax.collections), 1) 
    159130 
    160131    def testAddText(self): 
     
    250221        self.assertNotEqual(self.plotter.ax.get_ylim(), new_y) 
    251222 
     223    def testOnLinearFit(self): 
     224        """ Checks the response to LinearFit call """ 
     225        pass 
     226 
     227    def testOnRemovePlot(self): 
     228        """ Assure plots get removed when requested """ 
     229        pass 
     230 
     231    def testRemovePlot(self): 
     232        """ Test plot removal """ 
     233        pass 
     234 
     235    def testOnToggleHideError(self): 
     236        """ Test the error bar toggle on plots """ 
     237        pass 
     238 
     239    def testOnFitDisplay(self): 
     240        """ Test the fit line display on the chart """ 
     241        pass 
     242 
    252243if __name__ == "__main__": 
    253244    unittest.main() 
  • src/sas/qtgui/run_tests.bat

    rd3ca363 rb46f285  
    1919python -m UnitTesting.AddTextTest 
    2020python -m UnitTesting.SetGraphRangeTest 
     21python -m UnitTesting.LinearFitTest 
  • src/sas/qtgui/run_tests.sh

    rd3ca363 rb46f285  
    1818python -m UnitTesting.AddTextTest 
    1919python -m UnitTesting.SetGraphRangeTest 
     20python -m UnitTesting.LinearFitTest 
Note: See TracChangeset for help on using the changeset viewer.