source: sasview/src/sas/qtgui/Plotting/LinearFit.py @ 9c0ce68

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalc
Last change on this file since 9c0ce68 was d6b8a1d, checked in by Piotr Rozyczko <rozyczko@…>, 7 years ago

More Qt5 related fixes

  • Property mode set to 100644
File size: 8.5 KB
RevLine 
[570a58f9]1"""
2Adds a linear fit plot to the chart
3"""
[fed94a2]4import re
[570a58f9]5import numpy
[4992ff2]6from PyQt5 import QtCore
7from PyQt5 import QtGui
8from PyQt5 import QtWidgets
[570a58f9]9
[d6b8a1d]10from sas.qtgui.Utilities.GuiUtils import formatNumber, DoubleValidator
[570a58f9]11
[dc5ef15]12from sas.qtgui.Plotting import Fittings
13from sas.qtgui.Plotting import DataTransform
14from sas.qtgui.Plotting.LineModel import LineModel
[570a58f9]15
16# Local UI
[cd2cc745]17from sas.qtgui.UI import main_resources_rc
[83eb5208]18from sas.qtgui.Plotting.UI.LinearFitUI import Ui_LinearFitUI
[570a58f9]19
[4992ff2]20class LinearFit(QtWidgets.QDialog, Ui_LinearFitUI):
[7969b9c]21    updatePlot = QtCore.pyqtSignal(tuple)
[570a58f9]22    def __init__(self, parent=None,
23                 data=None,
24                 max_range=(0.0, 0.0),
25                 fit_range=(0.0, 0.0),
26                 xlabel="",
27                 ylabel=""):
28        super(LinearFit, self).__init__()
29
30        self.setupUi(self)
31        assert(isinstance(max_range, tuple))
32        assert(isinstance(fit_range, tuple))
33
34        self.data = data
35        self.parent = parent
36
37        self.max_range = max_range
38        self.fit_range = fit_range
39        self.xLabel = xlabel
40        self.yLabel = ylabel
41
[b46f285]42        self.x_is_log = self.xLabel == "log10(x)"
43        self.y_is_log = self.yLabel == "log10(y)"
44
[d6b8a1d]45        self.txtFitRangeMin.setValidator(DoubleValidator())
46        self.txtFitRangeMax.setValidator(DoubleValidator())
[570a58f9]47
48        # Default values in the line edits
49        self.txtA.setText("1")
50        self.txtB.setText("1")
51        self.txtAerr.setText("0")
52        self.txtBerr.setText("0")
53        self.txtChi2.setText("0")
[fed94a2]54
[570a58f9]55        # Initial ranges
56        self.txtRangeMin.setText(str(max_range[0]))
57        self.txtRangeMax.setText(str(max_range[1]))
58        self.txtFitRangeMin.setText(str(fit_range[0]))
59        self.txtFitRangeMax.setText(str(fit_range[1]))
60
[fed94a2]61        # cast xLabel into html
62        label = re.sub(r'\^\((.)\)(.*)', r'<span style=" vertical-align:super;">\1</span>\2',
63                      str(self.xLabel).rstrip())
64        self.lblRange.setText('Fit range of ' + label)
[570a58f9]65
66        self.model = LineModel()
67        # Display the fittings values
68        self.default_A = self.model.getParam('A')
69        self.default_B = self.model.getParam('B')
[dc5ef15]70        self.cstA = Fittings.Parameter(self.model, 'A', self.default_A)
71        self.cstB = Fittings.Parameter(self.model, 'B', self.default_B)
72        self.transform = DataTransform
[570a58f9]73
[2e3e959]74        self.setFixedSize(self.minimumSizeHint())
75
[fed94a2]76        # connect Fit button
[570a58f9]77        self.cmdFit.clicked.connect(self.fit)
78
79    def setRangeLabel(self, label=""):
80        """
81        Overwrite default fit range label to correspond to actual unit
82        """
[b3e8629]83        assert(isinstance(label, str))
[570a58f9]84        self.lblRange.setText(label)
85
86    def range(self):
87        return (float(self.txtFitRangeMin.text()), float(self.txtFitRangeMax.text()))
88
89    def fit(self, event):
90        """
91        Performs the fit. Receive an event when clicking on
92        the button Fit.Computes chisqr ,
93        A and B parameters of the best linear fit y=Ax +B
94        Push a plottable to the caller
95        """
96        tempx = []
97        tempy = []
98        tempdy = []
99
100        # Checks to assure data correctness
101        if len(self.data.view.x) < 2:
102            return
103        if not self.checkFitValues(self.txtFitRangeMin):
104            return
105
106        self.xminFit, self.xmaxFit = self.range()
107
[b46f285]108        xmin = self.xminFit
109        xmax = self.xmaxFit
110        xminView = xmin
111        xmaxView = xmax
112
[570a58f9]113        # Set the qmin and qmax in the panel that matches the
114        # transformed min and max
[fed94a2]115        #value_xmin = X_VAL_DICT[self.xLabel].floatTransform(xmin)
116        #value_xmax = X_VAL_DICT[self.xLabel].floatTransform(xmax)
[b46f285]117
[fed94a2]118        value_xmin = self.floatInvTransform(xmin)
119        value_xmax = self.floatInvTransform(xmax)
120        self.txtRangeMin.setText(formatNumber(value_xmin))
121        self.txtRangeMax.setText(formatNumber(value_xmax))
[570a58f9]122
[b46f285]123        tempx, tempy, tempdy = self.origData()
[570a58f9]124
125        # Find the fitting parameters
[dc5ef15]126        self.cstA = Fittings.Parameter(self.model, 'A', self.default_A)
127        self.cstB = Fittings.Parameter(self.model, 'B', self.default_B)
[570a58f9]128        tempdy = numpy.asarray(tempdy)
129        tempdy[tempdy == 0] = 1
130
[b46f285]131        if self.x_is_log:
132            xmin = numpy.log10(xmin)
133            xmax = numpy.log10(xmax)
134
[dc5ef15]135        chisqr, out, cov = Fittings.sasfit(self.model,
[b46f285]136                                           [self.cstA, self.cstB],
137                                           tempx, tempy, tempdy,
138                                           xmin, xmax)
[570a58f9]139        # Use chi2/dof
140        if len(tempx) > 0:
141            chisqr = chisqr / len(tempx)
142
143        # Check that cov and out are iterable before displaying them
144        errA = numpy.sqrt(cov[0][0]) if cov is not None else 0
145        errB = numpy.sqrt(cov[1][1]) if cov is not None else 0
146        cstA = out[0] if out is not None else 0.0
147        cstB = out[1] if out is not None else 0.0
148
149        # Reset model with the right values of A and B
150        self.model.setParam('A', float(cstA))
151        self.model.setParam('B', float(cstB))
152
153        tempx = []
154        tempy = []
155        y_model = 0.0
156
[b46f285]157        # load tempy with the minimum transformation
158        y_model = self.model.run(xmin)
159        tempx.append(xminView)
[d6b8a1d]160        tempy.append(numpy.power(10.0, y_model) if self.y_is_log else y_model)
[570a58f9]161
162        # load tempy with the maximum transformation
[b46f285]163        y_model = self.model.run(xmax)
164        tempx.append(xmaxView)
[d6b8a1d]165        tempy.append(numpy.power(10.0, y_model) if self.y_is_log else y_model)
[570a58f9]166
167        # Set the fit parameter display when  FitDialog is opened again
168        self.Avalue = cstA
169        self.Bvalue = cstB
170        self.ErrAvalue = errA
171        self.ErrBvalue = errB
172        self.Chivalue = chisqr
173
174        # Update the widget
175        self.txtA.setText(formatNumber(self.Avalue))
176        self.txtAerr.setText(formatNumber(self.ErrAvalue))
177        self.txtB.setText(formatNumber(self.Bvalue))
178        self.txtBerr.setText(formatNumber(self.ErrBvalue))
179        self.txtChi2.setText(formatNumber(self.Chivalue))
180
[7969b9c]181        self.updatePlot.emit((tempx, tempy))
[570a58f9]182
[b46f285]183    def origData(self):
184        # Store the transformed values of view x, y and dy before the fit
185        xmin_check = numpy.log10(self.xminFit)
186        # Local shortcuts
187        x = self.data.view.x
188        y = self.data.view.y
189        dy = self.data.view.dy
190
191        if self.y_is_log:
192            if self.x_is_log:
193                tempy  = [numpy.log10(y[i])
194                         for i in range(len(x)) if x[i] >= xmin_check]
[dc5ef15]195                tempdy = [DataTransform.errToLogX(y[i], 0, dy[i], 0)
[b46f285]196                         for i in range(len(x)) if x[i] >= xmin_check]
197            else:
[b3e8629]198                tempy = list(map(numpy.log10, y))
199                tempdy = list(map(lambda t1,t2:DataTransform.errToLogX(t1,0,t2,0),y,dy))
[b46f285]200        else:
201            tempy = y
202            tempdy = dy
203
204        if self.x_is_log:
205            tempx = [numpy.log10(x) for x in self.data.view.x if x > xmin_check]
206        else:
207            tempx = x
208
209        return tempx, tempy, tempdy
210
[570a58f9]211    def checkFitValues(self, item):
212        """
213        Check the validity of input values
214        """
215        flag = True
216        value = item.text()
217        p_white = item.palette()
218        p_white.setColor(item.backgroundRole(), QtCore.Qt.white)
219        p_pink = item.palette()
220        p_pink.setColor(item.backgroundRole(), QtGui.QColor(255, 128, 128))
[b46f285]221        item.setAutoFillBackground(True)
[570a58f9]222        # Check for possible values entered
[b46f285]223        if self.x_is_log:
[570a58f9]224            if float(value) > 0:
225                item.setPalette(p_white)
226            else:
227                flag = False
228                item.setPalette(p_pink)
229        return flag
230
231    def floatInvTransform(self, x):
232        """
233        transform a float.It is used to determine the x.View min and x.View
234        max for values not in x.  Also used to properly calculate RgQmin,
235        RgQmax and to update qmin and qmax in the linear range boxes on the
236        panel.
237
238        """
239        # TODO: refactor this. This is just a hack to make the
240        # functionality work without rewritting the whole code
241        # with good design (which really should be done...).
242        if self.xLabel == "x":
243            return x
244        elif self.xLabel == "x^(2)":
245            return numpy.sqrt(x)
246        elif self.xLabel == "x^(4)":
[b46f285]247            return numpy.sqrt(numpy.sqrt(x))
[570a58f9]248        elif self.xLabel == "log10(x)":
[d6b8a1d]249            return numpy.power(10.0, x)
[570a58f9]250        elif self.xLabel == "ln(x)":
251            return numpy.exp(x)
252        elif self.xLabel == "log10(x^(4))":
[d6b8a1d]253            return numpy.sqrt(numpy.sqrt(numpy.power(10.0, x)))
[570a58f9]254        return x
[2e3e959]255
256
Note: See TracBrowser for help on using the repository browser.