[e20870bc] | 1 | from functools import partial |
---|
| 2 | import copy |
---|
| 3 | import numpy as np |
---|
| 4 | |
---|
[4992ff2] | 5 | from PyQt5 import QtWidgets |
---|
[416fa8f] | 6 | |
---|
[dc5ef15] | 7 | from sas.qtgui.Plotting.PlotterData import Data2D |
---|
[416fa8f] | 8 | |
---|
| 9 | # Local UI |
---|
[cd2cc745] | 10 | from sas.qtgui.UI import main_resources_rc |
---|
[83eb5208] | 11 | from sas.qtgui.Plotting.UI.MaskEditorUI import Ui_MaskEditorUI |
---|
| 12 | from sas.qtgui.Plotting.Plotter2D import Plotter2DWidget |
---|
[416fa8f] | 13 | |
---|
[e20870bc] | 14 | from sas.qtgui.Plotting.Masks.SectorMask import SectorMask |
---|
| 15 | from sas.qtgui.Plotting.Masks.BoxMask import BoxMask |
---|
| 16 | from sas.qtgui.Plotting.Masks.CircularMask import CircularMask |
---|
| 17 | |
---|
| 18 | |
---|
class MaskEditor(QtWidgets.QDialog, Ui_MaskEditorUI):
    """
    Dialog for interactively editing the mask of a 2D data set.

    The data is drawn in an embedded ``Plotter2DWidget``. Toggling one of the
    radio buttons places a slicer (sector, circle or box) on the plot, and
    the push buttons are wired to ``onAdd``, ``onReset`` and ``onClear``.

    Mask convention used throughout this class: a ``True`` entry marks a
    point that is kept and plotted; ``False`` entries are drawn as NaN.
    """

    def __init__(self, parent=None, data=None):
        """
        Build the dialog, embed the 2D plot and show the current mask.

        :param parent: stored as ``self.parent`` and handed to the plotter
            as its ``manager``; it is not passed on to ``QDialog``.
        :param data: ``Data2D`` instance to edit. The edited mask is written
            back to ``data.mask``.
        """
        super(MaskEditor, self).__init__()

        assert isinstance(data, Data2D)

        self.setupUi(self)

        self.data = data
        self.parent = parent
        filename = data.name

        # Slicer currently shown on the plot, and the mask it last returned
        self.current_slicer = None
        self.slicer_mask = None

        self.setWindowTitle("Mask Editor for %s" % filename)

        self.plotter = Plotter2DWidget(self, manager=parent, quickplot=True)
        self.plotter.data = self.data
        # z-order passed to each new slicer; bumped before every creation
        self.slicer_z = 0
        # Last accepted mask; starts as a private copy of the incoming mask
        self.default_mask = copy.deepcopy(data.mask)

        layout = QtWidgets.QHBoxLayout()
        layout.setContentsMargins(0, 0, 0, 0)
        self.frame.setLayout(layout)

        self.plotter.plot()
        layout.addWidget(self.plotter)
        self.subplot = self.plotter.ax

        # Show the initial mask on the plot
        self.updateMask(self.default_mask)

        self.initializeSignals()

    def initializeSignals(self):
        """
        Attach slots to signals from the radio boxes and push buttons.
        """
        # (radio button, slicer class, value forwarded to onMask as `inside`)
        mask_buttons = (
            (self.rbWings, SectorMask, True),
            (self.rbCircularDisk, CircularMask, True),
            (self.rbRectangularDisk, BoxMask, True),
            (self.rbDoubleWingWindow, SectorMask, False),
            (self.rbCircularWindow, CircularMask, False),
            (self.rbRectangularWindow, BoxMask, False),
        )

        # Connect every slot first, then populate the group, so that the
        # connections all exist before any button joins the exclusive group.
        for button, slicer, inside in mask_buttons:
            button.toggled.connect(
                partial(self.onMask, slicer=slicer, inside=inside))

        # Button group defined so we can uncheck all buttons programmatically
        self.buttonGroup = QtWidgets.QButtonGroup()
        for button, _, _ in mask_buttons:
            self.buttonGroup.addButton(button)

        # Push buttons
        self.cmdAdd.clicked.connect(self.onAdd)
        self.cmdReset.clicked.connect(self.onReset)
        self.cmdClear.clicked.connect(self.onClear)

    def emptyRadioButtons(self):
        """
        Uncheck all buttons without them firing signals causing unnecessary
        slicer updates.
        """
        # Exclusivity is lifted while unchecking and restored afterwards.
        self.buttonGroup.setExclusive(False)
        for button in self.buttonGroup.buttons():
            button.blockSignals(True)
            button.setChecked(False)
            button.blockSignals(False)
        self.buttonGroup.setExclusive(True)

    def setSlicer(self, slicer):
        """
        Clear the previous slicer and create a new one.

        The slicer is drawn on the embedded plotter's axes and canvas, the
        same ones every other method of this dialog uses.

        :param slicer: slicer class to create; it is called as
            ``slicer(self, axes, zorder=...)``, i.e. with this dialog as its
            first argument.
        """
        # Clear current slicer
        if self.current_slicer is not None:
            self.current_slicer.clear()
        # Create a new slicer
        self.slicer_z += 1
        axes = self.plotter.ax
        self.current_slicer = slicer(self, axes, zorder=self.slicer_z)
        axes.set_ylim(self.data.ymin, self.data.ymax)
        axes.set_xlim(self.data.xmin, self.data.xmax)
        # Draw slicer
        self.plotter.canvas.draw()
        self.current_slicer.update()

    def onMask(self, slicer=None, inside=True):
        """
        Clear the previous slicer and create a new one of the given class.

        :param slicer: slicer class to instantiate on the plot. With the
            default ``None`` the current slicer is only cleared.
        :param inside: forwarded to the slicer as its ``side`` argument.
        """
        self.clearSlicer()
        if slicer is None:
            # Nothing to create; calling None would raise TypeError
            return

        self.slicer_z += 1
        self.current_slicer = slicer(self.plotter, self.plotter.ax,
                                     zorder=self.slicer_z, side=inside)
        self._resetAxesLimits()
        self.plotter.canvas.draw()

        # Keep the mask reported by the freshly drawn slicer
        self.slicer_mask = self.current_slicer.update()

    def update(self):
        """
        Redraw the canvas
        """
        # NOTE(review): setSlicer passes this dialog as the slicer's first
        # argument; presumably slicers call back into update() - confirm.
        self.plotter.draw()

    def onAdd(self):
        """
        Generate required mask and modify underlying DATA

        The mask returned by the current slicer is ANDed into
        ``self.data.mask``. Does nothing when no slicer is active.
        """
        if self.current_slicer is None:
            return
        self.slicer_mask = self.current_slicer.update()
        self.data.mask = self.data.mask & self.slicer_mask
        self.updateMask(self.data.mask)
        self.emptyRadioButtons()

    def onClear(self):
        """
        Discard the current slicer and restore ``data.mask`` from a copy of
        ``default_mask`` (the last mask accepted by ``updateMask``).
        """
        self._replaceSlicerWithBox()

        self.data.mask = copy.deepcopy(self.default_mask)
        # update mask plot
        self.updateMask(self.data.mask)
        self.emptyRadioButtons()

    def onReset(self):
        """
        Removes all the masks from data

        ``data.mask`` is replaced by an all-True boolean array of the same
        length.
        """
        self._replaceSlicerWithBox()

        mask = np.ones(len(self.data.mask), dtype=bool)
        self.data.mask = mask
        # update mask plot
        self.updateMask(mask)
        self.emptyRadioButtons()

    def clearSlicer(self):
        """
        Clear the slicer on the plot
        """
        if self.current_slicer is None:
            return

        self.current_slicer.clear()
        self.plotter.draw()
        self.current_slicer = None

    def updateMask(self, mask):
        """
        Respond to changes in masking

        Accepts or rejects *mask*, then redraws the plot with every point
        whose mask entry is False replaced by NaN. Any slicer still on the
        plot is cleared.

        :param mask: array indexable by itself (assumed boolean, same length
            as ``data.data`` - TODO confirm against callers).
        """
        # The case of too few True points: reject the new mask and fall back
        # to a copy of the last accepted one, both for the data and for the
        # plot drawn below, so that display and data stay in agreement.
        if len(mask[mask]) < 10 and self.data is not None:
            mask = copy.deepcopy(self.default_mask)
            self.data.mask = mask
        else:
            self.default_mask = mask

        # Make temporary data to plot, leaving self.data.data untouched
        temp_data = copy.deepcopy(self.data)
        # NaN distinguishes a masked point from a data point equal to 0
        temp_data.data[np.logical_not(mask)] = np.nan

        if self.current_slicer is not None:
            self.current_slicer.clear()
            self.current_slicer = None

        # modify imshow data
        self.plotter.plot(data=temp_data, update=True)
        self.plotter.draw()

        self.subplot = self.plotter.ax

    def _resetAxesLimits(self):
        """
        Set the plot axes limits to the x/y range stored on the data.
        """
        self.plotter.ax.set_ylim(self.data.ymin, self.data.ymax)
        self.plotter.ax.set_xlim(self.data.xmin, self.data.xmax)

    def _replaceSlicerWithBox(self):
        """
        Swap the current slicer for a fresh ``BoxMask`` and reset the axes.

        Shared first step of ``onClear`` and ``onReset``.
        """
        self.slicer_z += 1
        self.clearSlicer()
        self.current_slicer = BoxMask(self.plotter, self.plotter.ax,
                                      zorder=self.slicer_z, side=True)
        self._resetAxesLimits()
---|