source: sasview/src/sas/qtgui/Plotting/MaskEditor.py @ e20870bc

Branches: ESS_GUI, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc
Last change on this file since e20870bc was e20870bc, checked in by Piotr Rozyczko <rozyczko@…>, 6 years ago

Masking dialog for fitting

  • Property mode set to 100644
File size: 7.7 KB
Line 
1from functools import partial
2import copy
3import numpy as np
4
5from PyQt5 import QtWidgets
6
7from sas.qtgui.Plotting.PlotterData import Data2D
8
9# Local UI
10from sas.qtgui.UI import main_resources_rc
11from sas.qtgui.Plotting.UI.MaskEditorUI import Ui_MaskEditorUI
12from sas.qtgui.Plotting.Plotter2D import Plotter2DWidget
13
14from sas.qtgui.Plotting.Masks.SectorMask import SectorMask
15from sas.qtgui.Plotting.Masks.BoxMask import BoxMask
16from sas.qtgui.Plotting.Masks.CircularMask import CircularMask
17
18
class MaskEditor(QtWidgets.QDialog, Ui_MaskEditorUI):
    """
    Dialog for interactively editing the mask of a 2D data set.

    The data is shown in an embedded Plotter2DWidget. Selecting one of the
    radio buttons overlays a mask slicer (sector, circle or box) on the plot;
    "Add" combines the slicer's mask with the mask stored on the data, while
    "Clear" and "Reset" replace the data mask altogether.

    Mask convention used throughout this class: ``True`` marks a point that is
    kept (plotted), ``False`` marks a masked point (plotted as NaN).
    """

    # A mask leaving fewer than this many points unmasked is rejected by
    # updateMask() and the previously accepted mask is restored instead.
    MIN_UNMASKED_POINTS = 10

    def __init__(self, parent=None, data=None):
        """
        Build the dialog, embed the 2D plot and show the current data mask.

        :param parent: stored as ``self.parent`` and passed to the embedded
                       plotter as its ``manager``
        :param data: the Data2D instance whose ``mask`` is edited in place
        """
        super(MaskEditor, self).__init__()

        assert isinstance(data, Data2D)

        self.setupUi(self)

        self.data = data
        self.parent = parent
        filename = data.name

        self.current_slicer = None
        self.slicer_mask = None

        self.setWindowTitle("Mask Editor for %s" % filename)

        self.plotter = Plotter2DWidget(self, manager=parent, quickplot=True)
        self.plotter.data = self.data
        self.slicer_z = 0
        # Snapshot of the mask the data arrived with; onClear() restores it.
        self.default_mask = copy.deepcopy(data.mask)
        # Last accepted mask. It has to exist before the first updateMask()
        # call, which falls back to it when the incoming mask is rejected.
        self.mask = self.default_mask

        layout = QtWidgets.QHBoxLayout()
        layout.setContentsMargins(0, 0, 0, 0)
        self.frame.setLayout(layout)
        layout.addWidget(self.plotter)

        self.plotter.plot()
        self.subplot = self.plotter.ax

        # update mask
        self.updateMask(self.default_mask)

        self.initializeSignals()

    def initializeSignals(self):
        """
        Attach slots to signals from radio boxes and push buttons
        """
        # (radio button, slicer class, `side` flag handed to the slicer)
        mask_choices = (
            (self.rbWings, SectorMask, True),
            (self.rbCircularDisk, CircularMask, True),
            (self.rbRectangularDisk, BoxMask, True),
            (self.rbDoubleWingWindow, SectorMask, False),
            (self.rbCircularWindow, CircularMask, False),
            (self.rbRectangularWindow, BoxMask, False),
        )

        # Button group defined so we can uncheck all buttons programmatically
        self.buttonGroup = QtWidgets.QButtonGroup()
        for button, slicer, inside in mask_choices:
            button.toggled.connect(partial(self.onMask, slicer=slicer, inside=inside))
            self.buttonGroup.addButton(button)

        # Push buttons
        self.cmdAdd.clicked.connect(self.onAdd)
        self.cmdReset.clicked.connect(self.onReset)
        self.cmdClear.clicked.connect(self.onClear)

    def emptyRadioButtons(self):
        """
        Uncheck all buttons without them firing signals causing unnecessary slicer updates
        """
        # An exclusive group refuses to leave every button unchecked, so
        # exclusivity is lifted for the duration of the update.
        self.buttonGroup.setExclusive(False)
        for button in self.buttonGroup.buttons():
            # blockSignals() returns the previous blocking state; put it back
            # afterwards rather than unconditionally unblocking.
            was_blocked = button.blockSignals(True)
            button.setChecked(False)
            button.blockSignals(was_blocked)
        self.buttonGroup.setExclusive(True)

    def _restoreAxisLimits(self):
        """
        Set the plot axes back to the full x/y range of the data
        """
        self.plotter.ax.set_ylim(self.data.ymin, self.data.ymax)
        self.plotter.ax.set_xlim(self.data.xmin, self.data.xmax)

    def setSlicer(self, slicer):
        """
        Clear the previous slicer and create a new one.
        slicer: slicer class to create
        """
        # Clear current slicer
        if self.current_slicer is not None:
            self.current_slicer.clear()
        # Create a new slicer. The axes and canvas belong to the embedded
        # plotter (the dialog itself has no `ax`/`figure`), so the slicer is
        # attached to the plotter exactly as onMask() does.
        self.slicer_z += 1
        self.current_slicer = slicer(self.plotter, self.plotter.ax, zorder=self.slicer_z)
        self._restoreAxisLimits()
        # Draw slicer
        self.plotter.canvas.draw()
        self.current_slicer.update()

    def onMask(self, slicer=None, inside=True):
        """
        Clear the previous mask slicer and create a new one.

        slicer: slicer class to create; with None the old slicer is only removed
        inside: forwarded to the slicer as its `side` argument
        """
        self.clearSlicer()
        if slicer is None:
            # Nothing to instantiate - calling None would raise TypeError.
            return

        self.slicer_z += 1

        self.current_slicer = slicer(self.plotter, self.plotter.ax, zorder=self.slicer_z, side=inside)

        self._restoreAxisLimits()

        self.plotter.canvas.draw()

        # Remember the mask reported by the freshly created slicer
        self.slicer_mask = self.current_slicer.update()

    def update(self):
        """
        Redraw the canvas
        """
        self.plotter.draw()

    def onAdd(self):
        """
        Generate required mask and modify underlying DATA
        """
        if self.current_slicer is None:
            return
        self.slicer_mask = self.current_slicer.update()
        # A point stays unmasked only if both the existing mask and the
        # slicer mask keep it.
        self.data.mask = self.data.mask & self.slicer_mask
        self.updateMask(self.data.mask)
        self.emptyRadioButtons()

    def _replaceMask(self, mask):
        """
        Drop the active slicer, install `mask` on the data and refresh the plot
        """
        self.slicer_z += 1
        self.clearSlicer()
        # NOTE(review): this BoxMask is cleared again inside updateMask();
        # kept as in the original flow in case its construction has side
        # effects on the plot - confirm whether it is still needed.
        self.current_slicer = BoxMask(self.plotter, self.plotter.ax,
                                      zorder=self.slicer_z, side=True)
        self._restoreAxisLimits()

        self.data.mask = mask
        # update mask plot
        self.updateMask(mask)
        self.emptyRadioButtons()

    def onClear(self):
        """
        Discard the edits: restore the mask the data had when the dialog opened
        """
        self._replaceMask(copy.deepcopy(self.default_mask))

    def onReset(self):
        """
        Removes all the masks from data (every point becomes unmasked)
        """
        self._replaceMask(np.ones(len(self.data.mask), dtype=bool))

    def clearSlicer(self):
        """
        Clear the slicer on the plot
        """
        if self.current_slicer is None:
            return

        self.current_slicer.clear()
        self.plotter.draw()
        self.current_slicer = None

    def updateMask(self, mask):
        """
        Respond to changes in masking

        mask: boolean array, True for points to keep. A mask with fewer than
              MIN_UNMASKED_POINTS True entries is rejected: the last accepted
              mask is put back on the data and plotted instead.
        """
        # the case of too few True points
        if np.count_nonzero(mask) < self.MIN_UNMASKED_POINTS:
            # Fall back to the last accepted mask for BOTH the data and the
            # plot, so the display never disagrees with the data.
            mask = copy.deepcopy(self.mask)
            self.data.mask = mask
        else:
            self.mask = mask

        # make temporary data to plot, leaving the real intensities untouched
        temp_data = copy.deepcopy(self.data)
        # Masked points become NaN so they can be told apart from genuine
        # data points whose value is 0.
        temp_data.data[np.logical_not(mask)] = np.nan

        if self.current_slicer is not None:
            self.current_slicer.clear()
            self.current_slicer = None

        # modify imshow data
        self.plotter.plot(data=temp_data)
        self.plotter.draw()

        self.subplot = self.plotter.ax
Note: See TracBrowser for help on using the repository browser.