source: sasview/src/sas/qtgui/Plotting/LinearFit.py @ 4992ff2

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalc
Last change on this file since 4992ff2 was 4992ff2, checked in by Piotr Rozyczko <rozyczko@…>, 6 years ago

Initial, in-progress version. Not really working atm. SASVIEW-787

  • Property mode set to 100644
File size: 8.5 KB
Line 
1"""
2Adds a linear fit plot to the chart
3"""
4import re
5import numpy
6from PyQt5 import QtCore
7from PyQt5 import QtGui
8from PyQt5 import QtWidgets
9
10from sas.qtgui.Utilities.GuiUtils import formatNumber
11
12from sas.qtgui.Plotting import Fittings
13from sas.qtgui.Plotting import DataTransform
14from sas.qtgui.Plotting.LineModel import LineModel
15
16# Local UI
17from sas.qtgui.UI import main_resources_rc
18from sas.qtgui.Plotting.UI.LinearFitUI import Ui_LinearFitUI
19
20class LinearFit(QtWidgets.QDialog, Ui_LinearFitUI):
21    def __init__(self, parent=None,
22                 data=None,
23                 max_range=(0.0, 0.0),
24                 fit_range=(0.0, 0.0),
25                 xlabel="",
26                 ylabel=""):
27        super(LinearFit, self).__init__()
28
29        self.setupUi(self)
30        assert(isinstance(max_range, tuple))
31        assert(isinstance(fit_range, tuple))
32
33        self.data = data
34        self.parent = parent
35
36        self.max_range = max_range
37        self.fit_range = fit_range
38        self.xLabel = xlabel
39        self.yLabel = ylabel
40
41        self.x_is_log = self.xLabel == "log10(x)"
42        self.y_is_log = self.yLabel == "log10(y)"
43
44        self.txtFitRangeMin.setValidator(QtGui.QDoubleValidator())
45        self.txtFitRangeMax.setValidator(QtGui.QDoubleValidator())
46
47        # Default values in the line edits
48        self.txtA.setText("1")
49        self.txtB.setText("1")
50        self.txtAerr.setText("0")
51        self.txtBerr.setText("0")
52        self.txtChi2.setText("0")
53
54        # Initial ranges
55        self.txtRangeMin.setText(str(max_range[0]))
56        self.txtRangeMax.setText(str(max_range[1]))
57        self.txtFitRangeMin.setText(str(fit_range[0]))
58        self.txtFitRangeMax.setText(str(fit_range[1]))
59
60        # cast xLabel into html
61        label = re.sub(r'\^\((.)\)(.*)', r'<span style=" vertical-align:super;">\1</span>\2',
62                      str(self.xLabel).rstrip())
63        self.lblRange.setText('Fit range of ' + label)
64
65        self.model = LineModel()
66        # Display the fittings values
67        self.default_A = self.model.getParam('A')
68        self.default_B = self.model.getParam('B')
69        self.cstA = Fittings.Parameter(self.model, 'A', self.default_A)
70        self.cstB = Fittings.Parameter(self.model, 'B', self.default_B)
71        self.transform = DataTransform
72
73        self.setFixedSize(self.minimumSizeHint())
74
75        # connect Fit button
76        self.cmdFit.clicked.connect(self.fit)
77
78    def setRangeLabel(self, label=""):
79        """
80        Overwrite default fit range label to correspond to actual unit
81        """
82        assert(isinstance(label, str))
83        self.lblRange.setText(label)
84
85    def range(self):
86        return (float(self.txtFitRangeMin.text()), float(self.txtFitRangeMax.text()))
87
88    def fit(self, event):
89        """
90        Performs the fit. Receive an event when clicking on
91        the button Fit.Computes chisqr ,
92        A and B parameters of the best linear fit y=Ax +B
93        Push a plottable to the caller
94        """
95        tempx = []
96        tempy = []
97        tempdy = []
98
99        # Checks to assure data correctness
100        if len(self.data.view.x) < 2:
101            return
102        if not self.checkFitValues(self.txtFitRangeMin):
103            return
104
105        self.xminFit, self.xmaxFit = self.range()
106
107        xmin = self.xminFit
108        xmax = self.xmaxFit
109        xminView = xmin
110        xmaxView = xmax
111
112        # Set the qmin and qmax in the panel that matches the
113        # transformed min and max
114        #value_xmin = X_VAL_DICT[self.xLabel].floatTransform(xmin)
115        #value_xmax = X_VAL_DICT[self.xLabel].floatTransform(xmax)
116
117        value_xmin = self.floatInvTransform(xmin)
118        value_xmax = self.floatInvTransform(xmax)
119        self.txtRangeMin.setText(formatNumber(value_xmin))
120        self.txtRangeMax.setText(formatNumber(value_xmax))
121
122        tempx, tempy, tempdy = self.origData()
123
124        # Find the fitting parameters
125        self.cstA = Fittings.Parameter(self.model, 'A', self.default_A)
126        self.cstB = Fittings.Parameter(self.model, 'B', self.default_B)
127        tempdy = numpy.asarray(tempdy)
128        tempdy[tempdy == 0] = 1
129
130        if self.x_is_log:
131            xmin = numpy.log10(xmin)
132            xmax = numpy.log10(xmax)
133
134        chisqr, out, cov = Fittings.sasfit(self.model,
135                                           [self.cstA, self.cstB],
136                                           tempx, tempy, tempdy,
137                                           xmin, xmax)
138        # Use chi2/dof
139        if len(tempx) > 0:
140            chisqr = chisqr / len(tempx)
141
142        # Check that cov and out are iterable before displaying them
143        errA = numpy.sqrt(cov[0][0]) if cov is not None else 0
144        errB = numpy.sqrt(cov[1][1]) if cov is not None else 0
145        cstA = out[0] if out is not None else 0.0
146        cstB = out[1] if out is not None else 0.0
147
148        # Reset model with the right values of A and B
149        self.model.setParam('A', float(cstA))
150        self.model.setParam('B', float(cstB))
151
152        tempx = []
153        tempy = []
154        y_model = 0.0
155
156        # load tempy with the minimum transformation
157        y_model = self.model.run(xmin)
158        tempx.append(xminView)
159        tempy.append(numpy.power(10, y_model) if self.y_is_log else y_model)
160
161        # load tempy with the maximum transformation
162        y_model = self.model.run(xmax)
163        tempx.append(xmaxView)
164        tempy.append(numpy.power(10, y_model) if self.y_is_log else y_model)
165
166        # Set the fit parameter display when  FitDialog is opened again
167        self.Avalue = cstA
168        self.Bvalue = cstB
169        self.ErrAvalue = errA
170        self.ErrBvalue = errB
171        self.Chivalue = chisqr
172
173        # Update the widget
174        self.txtA.setText(formatNumber(self.Avalue))
175        self.txtAerr.setText(formatNumber(self.ErrAvalue))
176        self.txtB.setText(formatNumber(self.Bvalue))
177        self.txtBerr.setText(formatNumber(self.ErrBvalue))
178        self.txtChi2.setText(formatNumber(self.Chivalue))
179
180        #self.parent.updatePlot.emit((tempx, tempy))
181        self.parent.emit(QtCore.SIGNAL('updatePlot'), (tempx, tempy))
182
183    def origData(self):
184        # Store the transformed values of view x, y and dy before the fit
185        xmin_check = numpy.log10(self.xminFit)
186        # Local shortcuts
187        x = self.data.view.x
188        y = self.data.view.y
189        dy = self.data.view.dy
190
191        if self.y_is_log:
192            if self.x_is_log:
193                tempy  = [numpy.log10(y[i])
194                         for i in range(len(x)) if x[i] >= xmin_check]
195                tempdy = [DataTransform.errToLogX(y[i], 0, dy[i], 0)
196                         for i in range(len(x)) if x[i] >= xmin_check]
197            else:
198                tempy = list(map(numpy.log10, y))
199                tempdy = list(map(lambda t1,t2:DataTransform.errToLogX(t1,0,t2,0),y,dy))
200        else:
201            tempy = y
202            tempdy = dy
203
204        if self.x_is_log:
205            tempx = [numpy.log10(x) for x in self.data.view.x if x > xmin_check]
206        else:
207            tempx = x
208
209        return tempx, tempy, tempdy
210
211    def checkFitValues(self, item):
212        """
213        Check the validity of input values
214        """
215        flag = True
216        value = item.text()
217        p_white = item.palette()
218        p_white.setColor(item.backgroundRole(), QtCore.Qt.white)
219        p_pink = item.palette()
220        p_pink.setColor(item.backgroundRole(), QtGui.QColor(255, 128, 128))
221        item.setAutoFillBackground(True)
222        # Check for possible values entered
223        if self.x_is_log:
224            if float(value) > 0:
225                item.setPalette(p_white)
226            else:
227                flag = False
228                item.setPalette(p_pink)
229        return flag
230
231    def floatInvTransform(self, x):
232        """
233        transform a float.It is used to determine the x.View min and x.View
234        max for values not in x.  Also used to properly calculate RgQmin,
235        RgQmax and to update qmin and qmax in the linear range boxes on the
236        panel.
237
238        """
239        # TODO: refactor this. This is just a hack to make the
240        # functionality work without rewritting the whole code
241        # with good design (which really should be done...).
242        if self.xLabel == "x":
243            return x
244        elif self.xLabel == "x^(2)":
245            return numpy.sqrt(x)
246        elif self.xLabel == "x^(4)":
247            return numpy.sqrt(numpy.sqrt(x))
248        elif self.xLabel == "log10(x)":
249            return numpy.power(10, x)
250        elif self.xLabel == "ln(x)":
251            return numpy.exp(x)
252        elif self.xLabel == "log10(x^(4))":
253            return numpy.sqrt(numpy.sqrt(numpy.power(10, x)))
254        return x
255
256
Note: See TracBrowser for help on using the repository browser.