source: sasview/src/sas/qtgui/LinearFit.py @ 570a58f9

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalc
Last change on this file since 570a58f9 was 570a58f9, checked in by Piotr Rozyczko <rozyczko@…>, 7 years ago

Linear fits for 1D charts

  • Property mode set to 100755
File size: 9.7 KB
Line 
1"""
2Adds a linear fit plot to the chart
3"""
4import numpy
5from PyQt4 import QtGui
6from PyQt4 import QtCore
7
8#from sas.sasgui.transform import
9from sas.qtgui.GuiUtils import formatNumber
10from sas.sasgui.plottools import fittings
11from sas.sasgui.plottools import transform
12
13from sas.sasgui.plottools.LineModel import LineModel
14
15# Local UI
16from sas.qtgui.UI.LinearFitUI import Ui_LinearFitUI
17
18class LinearFit(QtGui.QDialog, Ui_LinearFitUI):
19    def __init__(self, parent=None,
20                 data=None,
21                 max_range=(0.0, 0.0),
22                 fit_range=(0.0, 0.0),
23                 xlabel="",
24                 ylabel=""):
25        super(LinearFit, self).__init__()
26
27        self.setupUi(self)
28        assert(isinstance(max_range, tuple))
29        assert(isinstance(fit_range, tuple))
30
31        self.data = data
32        self.parent = parent
33
34        self.max_range = max_range
35        self.fit_range = fit_range
36        self.xLabel = xlabel
37        self.yLabel = ylabel
38
39        self.txtFitRangeMin.setValidator(QtGui.QDoubleValidator())
40        self.txtFitRangeMax.setValidator(QtGui.QDoubleValidator())
41
42        # Default values in the line edits
43        self.txtA.setText("1")
44        self.txtB.setText("1")
45        self.txtAerr.setText("0")
46        self.txtBerr.setText("0")
47        self.txtChi2.setText("0")
48        # Initial ranges
49        self.txtRangeMin.setText(str(max_range[0]))
50        self.txtRangeMax.setText(str(max_range[1]))
51        self.txtFitRangeMin.setText(str(fit_range[0]))
52        self.txtFitRangeMax.setText(str(fit_range[1]))
53
54        self.lblRange.setText('Fit range of ' + self.xLabel)
55
56        self.model = LineModel()
57        # Display the fittings values
58        self.default_A = self.model.getParam('A')
59        self.default_B = self.model.getParam('B')
60        self.cstA = fittings.Parameter(self.model, 'A', self.default_A)
61        self.cstB = fittings.Parameter(self.model, 'B', self.default_B)
62        self.transform = transform
63
64        self.cmdFit.clicked.connect(self.fit)
65
66    def setRangeLabel(self, label=""):
67        """
68        Overwrite default fit range label to correspond to actual unit
69        """
70        assert(isinstance(label, basestring))
71        self.lblRange.setText(label)
72
73    def a(self):
74        return (float(self.txtA.text()), float(self.txtAerr.text()))
75
76    def b(self):
77        return (float(self.txtB.text()), float(self.txtBerr.text()))
78
79    def chi(self):
80        return float(self.txtChi2.text())
81
82    def range(self):
83        return (float(self.txtFitRangeMin.text()), float(self.txtFitRangeMax.text()))
84
85    def fit(self, event):
86        """
87        Performs the fit. Receive an event when clicking on
88        the button Fit.Computes chisqr ,
89        A and B parameters of the best linear fit y=Ax +B
90        Push a plottable to the caller
91        """
92        tempx = []
93        tempy = []
94        tempdy = []
95
96        # Checks to assure data correctness
97        if len(self.data.view.x) < 2:
98            return
99        if not self.checkFitValues(self.txtFitRangeMin):
100            return
101
102        self.xminFit, self.xmaxFit = self.range()
103
104        xminView = self.xminFit
105        xmaxView = self.xmaxFit
106        xmin = xminView
107        xmax = xmaxView
108        # Set the qmin and qmax in the panel that matches the
109        # transformed min and max
110        self.txtRangeMin.setText(formatNumber(self.floatInvTransform(xmin)))
111        self.txtRangeMax.setText(formatNumber(self.floatInvTransform(xmax)))
112
113        # Store the transformed values of view x, y and dy before the fit
114        xmin_check = numpy.log10(xmin)
115        x = self.data.view.x
116        y = self.data.view.y
117        dy = self.data.view.dy
118
119        if self.yLabel == "log10(y)":
120            if self.xLabel == "log10(x)":
121                tempy  = [numpy.log10(y[i])
122                         for i in range(len(x)) if x[i] >= xmin_check]
123                tempdy = [transform.errToLogX(y[i], 0, dy[i], 0)
124                         for i in range(len(x)) if x[i] >= xmin_check]
125            else:
126                tempy = map(numpy.log10, y)
127                tempdy = map(lambda t1,t2:transform.errToLogX(t1,0,t2,0),y,dy)
128        else:
129            tempy = y
130            tempdy = dy
131
132        if self.xLabel == "log10(x)":
133            tempx = [numpy.log10(x) for x in self.data.view.x if x > xmin_check]
134        else:
135            tempx = self.data.view.x
136
137        # Find the fitting parameters
138        # Always use the same defaults, so that fit history
139        # doesn't play a role!
140        self.cstA = fittings.Parameter(self.model, 'A', self.default_A)
141        self.cstB = fittings.Parameter(self.model, 'B', self.default_B)
142        tempdy = numpy.asarray(tempdy)
143        tempdy[tempdy == 0] = 1
144
145        if self.xLabel == "log10(x)":
146            chisqr, out, cov = fittings.sasfit(self.model,
147                                               [self.cstA, self.cstB],
148                                               tempx, tempy,
149                                               tempdy,
150                                               numpy.log10(xmin),
151                                               numpy.log10(xmax))
152        else:
153            chisqr, out, cov = fittings.sasfit(self.model,
154                                               [self.cstA, self.cstB],
155                                               tempx, tempy, tempdy,
156                                               xminView, xmaxView)
157        # Use chi2/dof
158        if len(tempx) > 0:
159            chisqr = chisqr / len(tempx)
160
161        # Check that cov and out are iterable before displaying them
162        errA = numpy.sqrt(cov[0][0]) if cov is not None else 0
163        errB = numpy.sqrt(cov[1][1]) if cov is not None else 0
164        cstA = out[0] if out is not None else 0.0
165        cstB = out[1] if out is not None else 0.0
166
167        # Reset model with the right values of A and B
168        self.model.setParam('A', float(cstA))
169        self.model.setParam('B', float(cstB))
170
171        tempx = []
172        tempy = []
173        y_model = 0.0
174        # load tempy with the minimum transformation
175        if self.xLabel == "log10(x)":
176            y_model = self.model.run(numpy.log10(xmin))
177            tempx.append(xmin)
178        else:
179            y_model = self.model.run(xminView)
180            tempx.append(xminView)
181
182        if self.yLabel == "log10(y)":
183            tempy.append(numpy.power(10, y_model))
184        else:
185            tempy.append(y_model)
186
187        # load tempy with the maximum transformation
188        if self.xLabel == "log10(x)":
189            y_model = self.model.run(numpy.log10(xmax))
190            tempx.append(xmax)
191        else:
192            y_model = self.model.run(xmaxView)
193            tempx.append(xmaxView)
194
195        if self.yLabel == "log10(y)":
196            tempy.append(numpy.power(10, y_model))
197        else:
198            tempy.append(y_model)
199        # Set the fit parameter display when  FitDialog is opened again
200        self.Avalue = cstA
201        self.Bvalue = cstB
202        self.ErrAvalue = errA
203        self.ErrBvalue = errB
204        self.Chivalue = chisqr
205
206        # Update the widget
207        self.txtA.setText(formatNumber(self.Avalue))
208        self.txtAerr.setText(formatNumber(self.ErrAvalue))
209        self.txtB.setText(formatNumber(self.Bvalue))
210        self.txtBerr.setText(formatNumber(self.ErrBvalue))
211        self.txtChi2.setText(formatNumber(self.Chivalue))
212
213        #self.parent.updatePlot.emit((tempx, tempy))
214        self.parent.emit(QtCore.SIGNAL('updatePlot'), (tempx, tempy))
215
216    def checkFitValues(self, item):
217        """
218        Check the validity of input values
219        """
220        flag = True
221        value = item.text()
222        p_white = item.palette()
223        p_white.setColor(item.backgroundRole(), QtCore.Qt.white)
224        p_pink = item.palette()
225        p_pink.setColor(item.backgroundRole(), QtGui.QColor(255, 128, 128))
226        # Check for possible values entered
227        if self.xLabel == "log10(x)":
228            if float(value) > 0:
229                item.setPalette(p_white)
230            else:
231                flag = False
232                item.setPalette(p_pink)
233        return flag
234
235    #def checkFitRanges(self, usermin, usermax):
236    #    """
237    #    Ensure that fields parameter contains a min and a max value
238    #    within x min and x max range
239    #    """
240    #    self.mini = float(self.xminFit.GetValue())
241    #    self.maxi = float(self.xmaxFit.GetValue())
242    #    flag = True
243    #    try:
244    #        mini = float(usermin.GetValue())
245    #        maxi = float(usermax.GetValue())
246    #        if mini < maxi:
247    #            usermin.SetBackgroundColour(wx.WHITE)
248    #            usermin.Refresh()
249    #        else:
250    #            flag = False
251    #            usermin.SetBackgroundColour("pink")
252    #            usermin.Refresh()
253    #    except:
254    #        # Check for possible values entered
255    #        flag = False
256    #        usermin.SetBackgroundColour("pink")
257    #        usermin.Refresh()
258
259    #    return flag
260    def floatInvTransform(self, x):
261        """
262        transform a float.It is used to determine the x.View min and x.View
263        max for values not in x.  Also used to properly calculate RgQmin,
264        RgQmax and to update qmin and qmax in the linear range boxes on the
265        panel.
266
267        """
268        # TODO: refactor this. This is just a hack to make the
269        # functionality work without rewritting the whole code
270        # with good design (which really should be done...).
271        if self.xLabel == "x":
272            return x
273        elif self.xLabel == "x^(2)":
274            return numpy.sqrt(x)
275        elif self.xLabel == "x^(4)":
276            return numpy.sqrt(math.sqrt(x))
277        elif self.xLabel == "log10(x)":
278            return numpy.power(10, x)
279        elif self.xLabel == "ln(x)":
280            return numpy.exp(x)
281        elif self.xLabel == "log10(x^(4))":
282            return numpy.sqrt(numpy.sqrt(numpy.power(10, x)))
283        return x
Note: See TracBrowser for help on using the repository browser.