source: sasview/src/sas/qtgui/Plotting/LinearFit.py @ 304e42f

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalc
Last change on this file since 304e42f was e4c475b7, checked in by Piotr Rozyczko <rozyczko@…>, 7 years ago

Minor fixes

  • Property mode set to 100644
File size: 8.6 KB
RevLine 
[570a58f9]1"""
2Adds a linear fit plot to the chart
3"""
[fed94a2]4import re
[570a58f9]5import numpy
[4992ff2]6from PyQt5 import QtCore
7from PyQt5 import QtGui
8from PyQt5 import QtWidgets
[570a58f9]9
[d6b8a1d]10from sas.qtgui.Utilities.GuiUtils import formatNumber, DoubleValidator
[570a58f9]11
[dc5ef15]12from sas.qtgui.Plotting import Fittings
13from sas.qtgui.Plotting import DataTransform
14from sas.qtgui.Plotting.LineModel import LineModel
[e4c475b7]15import sas.qtgui.Utilities.GuiUtils as GuiUtils
[570a58f9]16
17# Local UI
[cd2cc745]18from sas.qtgui.UI import main_resources_rc
[83eb5208]19from sas.qtgui.Plotting.UI.LinearFitUI import Ui_LinearFitUI
[570a58f9]20
[4992ff2]21class LinearFit(QtWidgets.QDialog, Ui_LinearFitUI):
[7969b9c]22    updatePlot = QtCore.pyqtSignal(tuple)
[570a58f9]23    def __init__(self, parent=None,
24                 data=None,
25                 max_range=(0.0, 0.0),
26                 fit_range=(0.0, 0.0),
27                 xlabel="",
28                 ylabel=""):
29        super(LinearFit, self).__init__()
30
31        self.setupUi(self)
32        assert(isinstance(max_range, tuple))
33        assert(isinstance(fit_range, tuple))
34
35        self.data = data
36        self.parent = parent
37
38        self.max_range = max_range
39        self.fit_range = fit_range
40        self.xLabel = xlabel
41        self.yLabel = ylabel
42
[b46f285]43        self.x_is_log = self.xLabel == "log10(x)"
44        self.y_is_log = self.yLabel == "log10(y)"
45
[d6b8a1d]46        self.txtFitRangeMin.setValidator(DoubleValidator())
47        self.txtFitRangeMax.setValidator(DoubleValidator())
[570a58f9]48
49        # Default values in the line edits
50        self.txtA.setText("1")
51        self.txtB.setText("1")
52        self.txtAerr.setText("0")
53        self.txtBerr.setText("0")
54        self.txtChi2.setText("0")
[fed94a2]55
[570a58f9]56        # Initial ranges
57        self.txtRangeMin.setText(str(max_range[0]))
58        self.txtRangeMax.setText(str(max_range[1]))
[e4c475b7]59        # Assure nice display of ranges
60        fr_min = GuiUtils.formatNumber(fit_range[0])
61        fr_max = GuiUtils.formatNumber(fit_range[1])
62        self.txtFitRangeMin.setText(str(fr_min))
63        self.txtFitRangeMax.setText(str(fr_max))
[570a58f9]64
[fed94a2]65        # cast xLabel into html
66        label = re.sub(r'\^\((.)\)(.*)', r'<span style=" vertical-align:super;">\1</span>\2',
67                      str(self.xLabel).rstrip())
68        self.lblRange.setText('Fit range of ' + label)
[570a58f9]69
70        self.model = LineModel()
71        # Display the fittings values
72        self.default_A = self.model.getParam('A')
73        self.default_B = self.model.getParam('B')
[dc5ef15]74        self.cstA = Fittings.Parameter(self.model, 'A', self.default_A)
75        self.cstB = Fittings.Parameter(self.model, 'B', self.default_B)
76        self.transform = DataTransform
[570a58f9]77
[2e3e959]78        self.setFixedSize(self.minimumSizeHint())
79
[fed94a2]80        # connect Fit button
[570a58f9]81        self.cmdFit.clicked.connect(self.fit)
82
83    def setRangeLabel(self, label=""):
84        """
85        Overwrite default fit range label to correspond to actual unit
86        """
[b3e8629]87        assert(isinstance(label, str))
[570a58f9]88        self.lblRange.setText(label)
89
90    def range(self):
91        return (float(self.txtFitRangeMin.text()), float(self.txtFitRangeMax.text()))
92
93    def fit(self, event):
94        """
95        Performs the fit. Receive an event when clicking on
96        the button Fit.Computes chisqr ,
97        A and B parameters of the best linear fit y=Ax +B
98        Push a plottable to the caller
99        """
100        tempx = []
101        tempy = []
102        tempdy = []
103
104        # Checks to assure data correctness
105        if len(self.data.view.x) < 2:
106            return
107        if not self.checkFitValues(self.txtFitRangeMin):
108            return
109
110        self.xminFit, self.xmaxFit = self.range()
111
[b46f285]112        xmin = self.xminFit
113        xmax = self.xmaxFit
114        xminView = xmin
115        xmaxView = xmax
116
[570a58f9]117        # Set the qmin and qmax in the panel that matches the
118        # transformed min and max
[fed94a2]119        value_xmin = self.floatInvTransform(xmin)
120        value_xmax = self.floatInvTransform(xmax)
121        self.txtRangeMin.setText(formatNumber(value_xmin))
122        self.txtRangeMax.setText(formatNumber(value_xmax))
[570a58f9]123
[b46f285]124        tempx, tempy, tempdy = self.origData()
[570a58f9]125
126        # Find the fitting parameters
[dc5ef15]127        self.cstA = Fittings.Parameter(self.model, 'A', self.default_A)
128        self.cstB = Fittings.Parameter(self.model, 'B', self.default_B)
[570a58f9]129        tempdy = numpy.asarray(tempdy)
130        tempdy[tempdy == 0] = 1
131
[b46f285]132        if self.x_is_log:
133            xmin = numpy.log10(xmin)
134            xmax = numpy.log10(xmax)
135
[dc5ef15]136        chisqr, out, cov = Fittings.sasfit(self.model,
[b46f285]137                                           [self.cstA, self.cstB],
138                                           tempx, tempy, tempdy,
139                                           xmin, xmax)
[570a58f9]140        # Use chi2/dof
141        if len(tempx) > 0:
142            chisqr = chisqr / len(tempx)
143
144        # Check that cov and out are iterable before displaying them
145        errA = numpy.sqrt(cov[0][0]) if cov is not None else 0
146        errB = numpy.sqrt(cov[1][1]) if cov is not None else 0
147        cstA = out[0] if out is not None else 0.0
148        cstB = out[1] if out is not None else 0.0
149
150        # Reset model with the right values of A and B
151        self.model.setParam('A', float(cstA))
152        self.model.setParam('B', float(cstB))
153
154        tempx = []
155        tempy = []
156        y_model = 0.0
157
[b46f285]158        # load tempy with the minimum transformation
159        y_model = self.model.run(xmin)
160        tempx.append(xminView)
[d6b8a1d]161        tempy.append(numpy.power(10.0, y_model) if self.y_is_log else y_model)
[570a58f9]162
163        # load tempy with the maximum transformation
[b46f285]164        y_model = self.model.run(xmax)
165        tempx.append(xmaxView)
[d6b8a1d]166        tempy.append(numpy.power(10.0, y_model) if self.y_is_log else y_model)
[570a58f9]167
168        # Set the fit parameter display when  FitDialog is opened again
169        self.Avalue = cstA
170        self.Bvalue = cstB
171        self.ErrAvalue = errA
172        self.ErrBvalue = errB
173        self.Chivalue = chisqr
174
175        # Update the widget
176        self.txtA.setText(formatNumber(self.Avalue))
177        self.txtAerr.setText(formatNumber(self.ErrAvalue))
178        self.txtB.setText(formatNumber(self.Bvalue))
179        self.txtBerr.setText(formatNumber(self.ErrBvalue))
180        self.txtChi2.setText(formatNumber(self.Chivalue))
181
[7969b9c]182        self.updatePlot.emit((tempx, tempy))
[570a58f9]183
[b46f285]184    def origData(self):
185        # Store the transformed values of view x, y and dy before the fit
186        xmin_check = numpy.log10(self.xminFit)
187        # Local shortcuts
188        x = self.data.view.x
189        y = self.data.view.y
190        dy = self.data.view.dy
191
192        if self.y_is_log:
193            if self.x_is_log:
194                tempy  = [numpy.log10(y[i])
195                         for i in range(len(x)) if x[i] >= xmin_check]
[dc5ef15]196                tempdy = [DataTransform.errToLogX(y[i], 0, dy[i], 0)
[b46f285]197                         for i in range(len(x)) if x[i] >= xmin_check]
198            else:
[b3e8629]199                tempy = list(map(numpy.log10, y))
200                tempdy = list(map(lambda t1,t2:DataTransform.errToLogX(t1,0,t2,0),y,dy))
[b46f285]201        else:
202            tempy = y
203            tempdy = dy
204
205        if self.x_is_log:
206            tempx = [numpy.log10(x) for x in self.data.view.x if x > xmin_check]
207        else:
208            tempx = x
209
210        return tempx, tempy, tempdy
211
[570a58f9]212    def checkFitValues(self, item):
213        """
214        Check the validity of input values
215        """
216        flag = True
217        value = item.text()
218        p_white = item.palette()
219        p_white.setColor(item.backgroundRole(), QtCore.Qt.white)
220        p_pink = item.palette()
221        p_pink.setColor(item.backgroundRole(), QtGui.QColor(255, 128, 128))
[b46f285]222        item.setAutoFillBackground(True)
[570a58f9]223        # Check for possible values entered
[b46f285]224        if self.x_is_log:
[570a58f9]225            if float(value) > 0:
226                item.setPalette(p_white)
227            else:
228                flag = False
229                item.setPalette(p_pink)
230        return flag
231
232    def floatInvTransform(self, x):
233        """
234        transform a float.It is used to determine the x.View min and x.View
235        max for values not in x.  Also used to properly calculate RgQmin,
236        RgQmax and to update qmin and qmax in the linear range boxes on the
237        panel.
238
239        """
240        # TODO: refactor this. This is just a hack to make the
241        # functionality work without rewritting the whole code
242        # with good design (which really should be done...).
243        if self.xLabel == "x":
244            return x
245        elif self.xLabel == "x^(2)":
246            return numpy.sqrt(x)
247        elif self.xLabel == "x^(4)":
[b46f285]248            return numpy.sqrt(numpy.sqrt(x))
[570a58f9]249        elif self.xLabel == "log10(x)":
[d6b8a1d]250            return numpy.power(10.0, x)
[570a58f9]251        elif self.xLabel == "ln(x)":
252            return numpy.exp(x)
253        elif self.xLabel == "log10(x^(4))":
[d6b8a1d]254            return numpy.sqrt(numpy.sqrt(numpy.power(10.0, x)))
[570a58f9]255        return x
[2e3e959]256
257
Note: See TracBrowser for help on using the repository browser.