source: sasview/src/sas/qtgui/Plotting/LinearFit.py @ 464cd07

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalc
Last change on this file since 464cd07 was dc5ef15, checked in by Piotr Rozyczko <rozyczko@…>, 8 years ago

Removed qtgui dependency on sasgui and wx SASVIEW-590

  • Property mode set to 100644
File size: 8.5 KB
RevLine 
[570a58f9]1"""
2Adds a linear fit plot to the chart
3"""
[fed94a2]4import re
[570a58f9]5import numpy
6from PyQt4 import QtGui
7from PyQt4 import QtCore
8
[83eb5208]9from sas.qtgui.Utilities.GuiUtils import formatNumber
[570a58f9]10
[dc5ef15]11from sas.qtgui.Plotting import Fittings
12from sas.qtgui.Plotting import DataTransform
13from sas.qtgui.Plotting.LineModel import LineModel
[570a58f9]14
15# Local UI
[cd2cc745]16from sas.qtgui.UI import main_resources_rc
[83eb5208]17from sas.qtgui.Plotting.UI.LinearFitUI import Ui_LinearFitUI
[570a58f9]18
19class LinearFit(QtGui.QDialog, Ui_LinearFitUI):
20    def __init__(self, parent=None,
21                 data=None,
22                 max_range=(0.0, 0.0),
23                 fit_range=(0.0, 0.0),
24                 xlabel="",
25                 ylabel=""):
26        super(LinearFit, self).__init__()
27
28        self.setupUi(self)
29        assert(isinstance(max_range, tuple))
30        assert(isinstance(fit_range, tuple))
31
32        self.data = data
33        self.parent = parent
34
35        self.max_range = max_range
36        self.fit_range = fit_range
37        self.xLabel = xlabel
38        self.yLabel = ylabel
39
[b46f285]40        self.x_is_log = self.xLabel == "log10(x)"
41        self.y_is_log = self.yLabel == "log10(y)"
42
[570a58f9]43        self.txtFitRangeMin.setValidator(QtGui.QDoubleValidator())
44        self.txtFitRangeMax.setValidator(QtGui.QDoubleValidator())
45
46        # Default values in the line edits
47        self.txtA.setText("1")
48        self.txtB.setText("1")
49        self.txtAerr.setText("0")
50        self.txtBerr.setText("0")
51        self.txtChi2.setText("0")
[fed94a2]52
[570a58f9]53        # Initial ranges
54        self.txtRangeMin.setText(str(max_range[0]))
55        self.txtRangeMax.setText(str(max_range[1]))
56        self.txtFitRangeMin.setText(str(fit_range[0]))
57        self.txtFitRangeMax.setText(str(fit_range[1]))
58
[fed94a2]59        # cast xLabel into html
60        label = re.sub(r'\^\((.)\)(.*)', r'<span style=" vertical-align:super;">\1</span>\2',
61                      str(self.xLabel).rstrip())
62        self.lblRange.setText('Fit range of ' + label)
[570a58f9]63
64        self.model = LineModel()
65        # Display the fittings values
66        self.default_A = self.model.getParam('A')
67        self.default_B = self.model.getParam('B')
[dc5ef15]68        self.cstA = Fittings.Parameter(self.model, 'A', self.default_A)
69        self.cstB = Fittings.Parameter(self.model, 'B', self.default_B)
70        self.transform = DataTransform
[570a58f9]71
[2e3e959]72        self.setFixedSize(self.minimumSizeHint())
73
[fed94a2]74        # connect Fit button
[570a58f9]75        self.cmdFit.clicked.connect(self.fit)
76
77    def setRangeLabel(self, label=""):
78        """
79        Overwrite default fit range label to correspond to actual unit
80        """
81        assert(isinstance(label, basestring))
82        self.lblRange.setText(label)
83
84    def range(self):
85        return (float(self.txtFitRangeMin.text()), float(self.txtFitRangeMax.text()))
86
87    def fit(self, event):
88        """
89        Performs the fit. Receive an event when clicking on
90        the button Fit.Computes chisqr ,
91        A and B parameters of the best linear fit y=Ax +B
92        Push a plottable to the caller
93        """
94        tempx = []
95        tempy = []
96        tempdy = []
97
98        # Checks to assure data correctness
99        if len(self.data.view.x) < 2:
100            return
101        if not self.checkFitValues(self.txtFitRangeMin):
102            return
103
104        self.xminFit, self.xmaxFit = self.range()
105
[b46f285]106        xmin = self.xminFit
107        xmax = self.xmaxFit
108        xminView = xmin
109        xmaxView = xmax
110
[570a58f9]111        # Set the qmin and qmax in the panel that matches the
112        # transformed min and max
[fed94a2]113        #value_xmin = X_VAL_DICT[self.xLabel].floatTransform(xmin)
114        #value_xmax = X_VAL_DICT[self.xLabel].floatTransform(xmax)
[b46f285]115
[fed94a2]116        value_xmin = self.floatInvTransform(xmin)
117        value_xmax = self.floatInvTransform(xmax)
118        self.txtRangeMin.setText(formatNumber(value_xmin))
119        self.txtRangeMax.setText(formatNumber(value_xmax))
[570a58f9]120
[b46f285]121        tempx, tempy, tempdy = self.origData()
[570a58f9]122
123        # Find the fitting parameters
[dc5ef15]124        self.cstA = Fittings.Parameter(self.model, 'A', self.default_A)
125        self.cstB = Fittings.Parameter(self.model, 'B', self.default_B)
[570a58f9]126        tempdy = numpy.asarray(tempdy)
127        tempdy[tempdy == 0] = 1
128
[b46f285]129        if self.x_is_log:
130            xmin = numpy.log10(xmin)
131            xmax = numpy.log10(xmax)
132
[dc5ef15]133        chisqr, out, cov = Fittings.sasfit(self.model,
[b46f285]134                                           [self.cstA, self.cstB],
135                                           tempx, tempy, tempdy,
136                                           xmin, xmax)
[570a58f9]137        # Use chi2/dof
138        if len(tempx) > 0:
139            chisqr = chisqr / len(tempx)
140
141        # Check that cov and out are iterable before displaying them
142        errA = numpy.sqrt(cov[0][0]) if cov is not None else 0
143        errB = numpy.sqrt(cov[1][1]) if cov is not None else 0
144        cstA = out[0] if out is not None else 0.0
145        cstB = out[1] if out is not None else 0.0
146
147        # Reset model with the right values of A and B
148        self.model.setParam('A', float(cstA))
149        self.model.setParam('B', float(cstB))
150
151        tempx = []
152        tempy = []
153        y_model = 0.0
154
[b46f285]155        # load tempy with the minimum transformation
156        y_model = self.model.run(xmin)
157        tempx.append(xminView)
158        tempy.append(numpy.power(10, y_model) if self.y_is_log else y_model)
[570a58f9]159
160        # load tempy with the maximum transformation
[b46f285]161        y_model = self.model.run(xmax)
162        tempx.append(xmaxView)
163        tempy.append(numpy.power(10, y_model) if self.y_is_log else y_model)
[570a58f9]164
165        # Set the fit parameter display when  FitDialog is opened again
166        self.Avalue = cstA
167        self.Bvalue = cstB
168        self.ErrAvalue = errA
169        self.ErrBvalue = errB
170        self.Chivalue = chisqr
171
172        # Update the widget
173        self.txtA.setText(formatNumber(self.Avalue))
174        self.txtAerr.setText(formatNumber(self.ErrAvalue))
175        self.txtB.setText(formatNumber(self.Bvalue))
176        self.txtBerr.setText(formatNumber(self.ErrBvalue))
177        self.txtChi2.setText(formatNumber(self.Chivalue))
178
179        #self.parent.updatePlot.emit((tempx, tempy))
180        self.parent.emit(QtCore.SIGNAL('updatePlot'), (tempx, tempy))
181
[b46f285]182    def origData(self):
183        # Store the transformed values of view x, y and dy before the fit
184        xmin_check = numpy.log10(self.xminFit)
185        # Local shortcuts
186        x = self.data.view.x
187        y = self.data.view.y
188        dy = self.data.view.dy
189
190        if self.y_is_log:
191            if self.x_is_log:
192                tempy  = [numpy.log10(y[i])
193                         for i in range(len(x)) if x[i] >= xmin_check]
[dc5ef15]194                tempdy = [DataTransform.errToLogX(y[i], 0, dy[i], 0)
[b46f285]195                         for i in range(len(x)) if x[i] >= xmin_check]
196            else:
197                tempy = map(numpy.log10, y)
[dc5ef15]198                tempdy = map(lambda t1,t2:DataTransform.errToLogX(t1,0,t2,0),y,dy)
[b46f285]199        else:
200            tempy = y
201            tempdy = dy
202
203        if self.x_is_log:
204            tempx = [numpy.log10(x) for x in self.data.view.x if x > xmin_check]
205        else:
206            tempx = x
207
208        return tempx, tempy, tempdy
209
[570a58f9]210    def checkFitValues(self, item):
211        """
212        Check the validity of input values
213        """
214        flag = True
215        value = item.text()
216        p_white = item.palette()
217        p_white.setColor(item.backgroundRole(), QtCore.Qt.white)
218        p_pink = item.palette()
219        p_pink.setColor(item.backgroundRole(), QtGui.QColor(255, 128, 128))
[b46f285]220        item.setAutoFillBackground(True)
[570a58f9]221        # Check for possible values entered
[b46f285]222        if self.x_is_log:
[570a58f9]223            if float(value) > 0:
224                item.setPalette(p_white)
225            else:
226                flag = False
227                item.setPalette(p_pink)
228        return flag
229
230    def floatInvTransform(self, x):
231        """
232        transform a float.It is used to determine the x.View min and x.View
233        max for values not in x.  Also used to properly calculate RgQmin,
234        RgQmax and to update qmin and qmax in the linear range boxes on the
235        panel.
236
237        """
238        # TODO: refactor this. This is just a hack to make the
239        # functionality work without rewritting the whole code
240        # with good design (which really should be done...).
241        if self.xLabel == "x":
242            return x
243        elif self.xLabel == "x^(2)":
244            return numpy.sqrt(x)
245        elif self.xLabel == "x^(4)":
[b46f285]246            return numpy.sqrt(numpy.sqrt(x))
[570a58f9]247        elif self.xLabel == "log10(x)":
248            return numpy.power(10, x)
249        elif self.xLabel == "ln(x)":
250            return numpy.exp(x)
251        elif self.xLabel == "log10(x^(4))":
252            return numpy.sqrt(numpy.sqrt(numpy.power(10, x)))
253        return x
[2e3e959]254
255
Note: See TracBrowser for help on using the repository browser.