source: sasview/src/sas/qtgui/LinearFit.py @ fed94a2

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalc
Last change on this file since fed94a2 was fed94a2, checked in by Piotr Rozyczko <rozyczko@…>, 7 years ago

Improved label formatting in charts

  • Property mode set to 100644
File size: 9.3 KB
Line 
1"""
2Adds a linear fit plot to the chart
3"""
4import re
5import numpy
6from PyQt4 import QtGui
7from PyQt4 import QtCore
8
9#from sas.sasgui.transform import
10from sas.qtgui.GuiUtils import formatNumber
11from sas.sasgui.plottools import fittings
12from sas.sasgui.plottools import transform
13
14from sas.sasgui.plottools.LineModel import LineModel
15
16# Local UI
17from sas.qtgui.UI.LinearFitUI import Ui_LinearFitUI
18
19class LinearFit(QtGui.QDialog, Ui_LinearFitUI):
20    def __init__(self, parent=None,
21                 data=None,
22                 max_range=(0.0, 0.0),
23                 fit_range=(0.0, 0.0),
24                 xlabel="",
25                 ylabel=""):
26        super(LinearFit, self).__init__()
27
28        self.setupUi(self)
29        assert(isinstance(max_range, tuple))
30        assert(isinstance(fit_range, tuple))
31
32        self.data = data
33        self.parent = parent
34
35        self.max_range = max_range
36        self.fit_range = fit_range
37        self.xLabel = xlabel
38        self.yLabel = ylabel
39
40        self.txtFitRangeMin.setValidator(QtGui.QDoubleValidator())
41        self.txtFitRangeMax.setValidator(QtGui.QDoubleValidator())
42
43        # Default values in the line edits
44        self.txtA.setText("1")
45        self.txtB.setText("1")
46        self.txtAerr.setText("0")
47        self.txtBerr.setText("0")
48        self.txtChi2.setText("0")
49
50        # Initial ranges
51        self.txtRangeMin.setText(str(max_range[0]))
52        self.txtRangeMax.setText(str(max_range[1]))
53        self.txtFitRangeMin.setText(str(fit_range[0]))
54        self.txtFitRangeMax.setText(str(fit_range[1]))
55
56        # cast xLabel into html
57        label = re.sub(r'\^\((.)\)(.*)', r'<span style=" vertical-align:super;">\1</span>\2',
58                      str(self.xLabel).rstrip())
59        self.lblRange.setText('Fit range of ' + label)
60
61        self.model = LineModel()
62        # Display the fittings values
63        self.default_A = self.model.getParam('A')
64        self.default_B = self.model.getParam('B')
65        self.cstA = fittings.Parameter(self.model, 'A', self.default_A)
66        self.cstB = fittings.Parameter(self.model, 'B', self.default_B)
67        self.transform = transform
68
69        # connect Fit button
70        self.cmdFit.clicked.connect(self.fit)
71
72    def setRangeLabel(self, label=""):
73        """
74        Overwrite default fit range label to correspond to actual unit
75        """
76        assert(isinstance(label, basestring))
77        self.lblRange.setText(label)
78
79    def a(self):
80        return (float(self.txtA.text()), float(self.txtAerr.text()))
81
82    def b(self):
83        return (float(self.txtB.text()), float(self.txtBerr.text()))
84
85    def chi(self):
86        return float(self.txtChi2.text())
87
88    def range(self):
89        return (float(self.txtFitRangeMin.text()), float(self.txtFitRangeMax.text()))
90
91    def fit(self, event):
92        """
93        Performs the fit. Receive an event when clicking on
94        the button Fit.Computes chisqr ,
95        A and B parameters of the best linear fit y=Ax +B
96        Push a plottable to the caller
97        """
98        tempx = []
99        tempy = []
100        tempdy = []
101
102        # Checks to assure data correctness
103        if len(self.data.view.x) < 2:
104            return
105        if not self.checkFitValues(self.txtFitRangeMin):
106            return
107
108        self.xminFit, self.xmaxFit = self.range()
109
110        xminView = self.xminFit
111        xmaxView = self.xmaxFit
112        xmin = xminView
113        xmax = xmaxView
114        # Set the qmin and qmax in the panel that matches the
115        # transformed min and max
116        #value_xmin = X_VAL_DICT[self.xLabel].floatTransform(xmin)
117        #value_xmax = X_VAL_DICT[self.xLabel].floatTransform(xmax)
118        value_xmin = self.floatInvTransform(xmin)
119        value_xmax = self.floatInvTransform(xmax)
120        self.txtRangeMin.setText(formatNumber(value_xmin))
121        self.txtRangeMax.setText(formatNumber(value_xmax))
122
123        # Store the transformed values of view x, y and dy before the fit
124        xmin_check = numpy.log10(xmin)
125        x = self.data.view.x
126        y = self.data.view.y
127        dy = self.data.view.dy
128
129        if self.yLabel == "log10(y)":
130            if self.xLabel == "log10(x)":
131                tempy  = [numpy.log10(y[i])
132                         for i in range(len(x)) if x[i] >= xmin_check]
133                tempdy = [transform.errToLogX(y[i], 0, dy[i], 0)
134                         for i in range(len(x)) if x[i] >= xmin_check]
135            else:
136                tempy = map(numpy.log10, y)
137                tempdy = map(lambda t1,t2:transform.errToLogX(t1,0,t2,0),y,dy)
138        else:
139            tempy = y
140            tempdy = dy
141
142        if self.xLabel == "log10(x)":
143            tempx = [numpy.log10(x) for x in self.data.view.x if x > xmin_check]
144        else:
145            tempx = self.data.view.x
146
147        # Find the fitting parameters
148        # Always use the same defaults, so that fit history
149        # doesn't play a role!
150        self.cstA = fittings.Parameter(self.model, 'A', self.default_A)
151        self.cstB = fittings.Parameter(self.model, 'B', self.default_B)
152        tempdy = numpy.asarray(tempdy)
153        tempdy[tempdy == 0] = 1
154
155        if self.xLabel == "log10(x)":
156            chisqr, out, cov = fittings.sasfit(self.model,
157                                               [self.cstA, self.cstB],
158                                               tempx, tempy,
159                                               tempdy,
160                                               numpy.log10(xmin),
161                                               numpy.log10(xmax))
162        else:
163            chisqr, out, cov = fittings.sasfit(self.model,
164                                               [self.cstA, self.cstB],
165                                               tempx, tempy, tempdy,
166                                               xminView, xmaxView)
167        # Use chi2/dof
168        if len(tempx) > 0:
169            chisqr = chisqr / len(tempx)
170
171        # Check that cov and out are iterable before displaying them
172        errA = numpy.sqrt(cov[0][0]) if cov is not None else 0
173        errB = numpy.sqrt(cov[1][1]) if cov is not None else 0
174        cstA = out[0] if out is not None else 0.0
175        cstB = out[1] if out is not None else 0.0
176
177        # Reset model with the right values of A and B
178        self.model.setParam('A', float(cstA))
179        self.model.setParam('B', float(cstB))
180
181        tempx = []
182        tempy = []
183        y_model = 0.0
184        # load tempy with the minimum transformation
185        if self.xLabel == "log10(x)":
186            y_model = self.model.run(numpy.log10(xmin))
187            tempx.append(xmin)
188        else:
189            y_model = self.model.run(xminView)
190            tempx.append(xminView)
191
192        if self.yLabel == "log10(y)":
193            tempy.append(numpy.power(10, y_model))
194        else:
195            tempy.append(y_model)
196
197        # load tempy with the maximum transformation
198        if self.xLabel == "log10(x)":
199            y_model = self.model.run(numpy.log10(xmax))
200            tempx.append(xmax)
201        else:
202            y_model = self.model.run(xmaxView)
203            tempx.append(xmaxView)
204
205        if self.yLabel == "log10(y)":
206            tempy.append(numpy.power(10, y_model))
207        else:
208            tempy.append(y_model)
209        # Set the fit parameter display when  FitDialog is opened again
210        self.Avalue = cstA
211        self.Bvalue = cstB
212        self.ErrAvalue = errA
213        self.ErrBvalue = errB
214        self.Chivalue = chisqr
215
216        # Update the widget
217        self.txtA.setText(formatNumber(self.Avalue))
218        self.txtAerr.setText(formatNumber(self.ErrAvalue))
219        self.txtB.setText(formatNumber(self.Bvalue))
220        self.txtBerr.setText(formatNumber(self.ErrBvalue))
221        self.txtChi2.setText(formatNumber(self.Chivalue))
222
223        #self.parent.updatePlot.emit((tempx, tempy))
224        self.parent.emit(QtCore.SIGNAL('updatePlot'), (tempx, tempy))
225
226    def checkFitValues(self, item):
227        """
228        Check the validity of input values
229        """
230        flag = True
231        value = item.text()
232        p_white = item.palette()
233        p_white.setColor(item.backgroundRole(), QtCore.Qt.white)
234        p_pink = item.palette()
235        p_pink.setColor(item.backgroundRole(), QtGui.QColor(255, 128, 128))
236        # Check for possible values entered
237        if self.xLabel == "log10(x)":
238            if float(value) > 0:
239                item.setPalette(p_white)
240            else:
241                flag = False
242                item.setPalette(p_pink)
243        return flag
244
245    def floatInvTransform(self, x):
246        """
247        transform a float.It is used to determine the x.View min and x.View
248        max for values not in x.  Also used to properly calculate RgQmin,
249        RgQmax and to update qmin and qmax in the linear range boxes on the
250        panel.
251
252        """
253        # TODO: refactor this. This is just a hack to make the
254        # functionality work without rewritting the whole code
255        # with good design (which really should be done...).
256        if self.xLabel == "x":
257            return x
258        elif self.xLabel == "x^(2)":
259            return numpy.sqrt(x)
260        elif self.xLabel == "x^(4)":
261            return numpy.sqrt(math.sqrt(x))
262        elif self.xLabel == "log10(x)":
263            return numpy.power(10, x)
264        elif self.xLabel == "ln(x)":
265            return numpy.exp(x)
266        elif self.xLabel == "log10(x^(4))":
267            return numpy.sqrt(numpy.sqrt(numpy.power(10, x)))
268        return x
Note: See TracBrowser for help on using the repository browser.