source: sasview/src/sas/qtgui/Calculators/GenericScatteringCalculator.py

ESS_GUI
Last change on this file was 33c0561, checked in by Piotr Rozyczko <piotr.rozyczko@…>, 5 years ago

Replace Apply button menu driven functionality with additional button.
Removed Cancel.
Removed the window system context help button from all affected widgets.
SASVIEW-1239

  • Property mode set to 100644
File size: 34.3 KB
Line 
1import sys
2import os
3import numpy
4import logging
5import time
6
7from PyQt5 import QtCore
8from PyQt5 import QtGui
9from PyQt5 import QtWidgets
10
11from twisted.internet import threads
12
13import sas.qtgui.Utilities.GuiUtils as GuiUtils
14from sas.qtgui.Utilities.GenericReader import GenReader
15from sas.sascalc.dataloader.data_info import Detector
16from sas.sascalc.dataloader.data_info import Source
17from sas.sascalc.calculator import sas_gen
18from sas.qtgui.Plotting.PlotterBase import PlotterBase
19from sas.qtgui.Plotting.Plotter2D import Plotter2D
20from sas.qtgui.Plotting.Plotter import Plotter
21
22from sas.qtgui.Plotting.PlotterData import Data1D
23from sas.qtgui.Plotting.PlotterData import Data2D
24
25# Local UI
26from .UI.GenericScatteringCalculator import Ui_GenericScatteringCalculator
27
28_Q1D_MIN = 0.001
29
30
31class GenericScatteringCalculator(QtWidgets.QDialog, Ui_GenericScatteringCalculator):
32
33    trigger_plot_3d = QtCore.pyqtSignal()
34    calculationFinishedSignal = QtCore.pyqtSignal()
35    loadingFinishedSignal = QtCore.pyqtSignal(list)
36
37    def __init__(self, parent=None):
38        super(GenericScatteringCalculator, self).__init__()
39        self.setupUi(self)
40        # disable the context help icon
41        self.setWindowFlags(self.windowFlags() & ~QtCore.Qt.WindowContextHelpButtonHint)
42
43        self.manager = parent
44        self.communicator = self.manager.communicator()
45        self.model = sas_gen.GenSAS()
46        self.omf_reader = sas_gen.OMFReader()
47        self.sld_reader = sas_gen.SLDReader()
48        self.pdb_reader = sas_gen.PDBReader()
49        self.reader = None
50        self.sld_data = None
51
52        self.parameters = []
53        self.data = None
54        self.datafile = None
55        self.file_name = ''
56        self.ext = None
57        self.default_shape = str(self.cbShape.currentText())
58        self.is_avg = False
59        self.data_to_plot = None
60        self.graph_num = 1  # index for name of graph
61
62        # combox box
63        self.cbOptionsCalc.setVisible(False)
64
65        # push buttons
66        self.cmdClose.clicked.connect(self.accept)
67        self.cmdHelp.clicked.connect(self.onHelp)
68
69        self.cmdLoad.clicked.connect(self.loadFile)
70        self.cmdCompute.clicked.connect(self.onCompute)
71        self.cmdReset.clicked.connect(self.onReset)
72        self.cmdSave.clicked.connect(self.onSaveFile)
73
74        self.cmdDraw.clicked.connect(lambda: self.plot3d(has_arrow=True))
75        self.cmdDrawpoints.clicked.connect(lambda: self.plot3d(has_arrow=False))
76
77        # validators
78        # scale, volume and background must be positive
79        validat_regex_pos = QtCore.QRegExp('^[+]?([.]\d+|\d+([.]\d+)?)$')
80        self.txtScale.setValidator(QtGui.QRegExpValidator(validat_regex_pos,
81                                                          self.txtScale))
82        self.txtBackground.setValidator(QtGui.QRegExpValidator(
83            validat_regex_pos, self.txtBackground))
84        self.txtTotalVolume.setValidator(QtGui.QRegExpValidator(
85            validat_regex_pos, self.txtTotalVolume))
86
87        # fraction of spin up between 0 and 1
88        validat_regexbetween0_1 = QtCore.QRegExp('^(0(\.\d*)*|1(\.0+)?)$')
89        self.txtUpFracIn.setValidator(
90            QtGui.QRegExpValidator(validat_regexbetween0_1, self.txtUpFracIn))
91        self.txtUpFracOut.setValidator(
92            QtGui.QRegExpValidator(validat_regexbetween0_1, self.txtUpFracOut))
93
94        # 0 < Qmax <= 1000
95        validat_regex_q = QtCore.QRegExp('^1000$|^[+]?(\d{1,3}([.]\d+)?)$')
96        self.txtQxMax.setValidator(QtGui.QRegExpValidator(validat_regex_q,
97                                                          self.txtQxMax))
98        self.txtQxMax.textChanged.connect(self.check_value)
99
100        # 2 <= Qbin <= 1000
101        self.txtNoQBins.setValidator(QtGui.QRegExpValidator(validat_regex_q,
102                                                            self.txtNoQBins))
103        self.txtNoQBins.textChanged.connect(self.check_value)
104
105        # plots - 3D in real space
106        self.trigger_plot_3d.connect(lambda: self.plot3d(has_arrow=False))
107
108        # plots - 3D in real space
109        self.calculationFinishedSignal.connect(self.plot_1_2d)
110
111        # notify main thread about file load complete
112        self.loadingFinishedSignal.connect(self.complete_loading)
113
114        # TODO the option Ellipsoid has not been implemented
115        self.cbShape.currentIndexChanged.connect(self.selectedshapechange)
116
117        # New font to display angstrom symbol
118        new_font = 'font-family: -apple-system, "Helvetica Neue", "Ubuntu";'
119        self.lblUnitSolventSLD.setStyleSheet(new_font)
120        self.lblUnitVolume.setStyleSheet(new_font)
121        self.lbl5.setStyleSheet(new_font)
122        self.lblUnitMx.setStyleSheet(new_font)
123        self.lblUnitMy.setStyleSheet(new_font)
124        self.lblUnitMz.setStyleSheet(new_font)
125        self.lblUnitNucl.setStyleSheet(new_font)
126        self.lblUnitx.setStyleSheet(new_font)
127        self.lblUnity.setStyleSheet(new_font)
128        self.lblUnitz.setStyleSheet(new_font)
129
130    def selectedshapechange(self):
131        """
132        TODO Temporary solution to display information about option 'Ellipsoid'
133        """
134        print("The option Ellipsoid has not been implemented yet.")
135        self.communicator.statusBarUpdateSignal.emit(
136            "The option Ellipsoid has not been implemented yet.")
137
138    def loadFile(self):
139        """
140        Open menu to choose the datafile to load
141        Only extensions .SLD, .PDB, .OMF, .sld, .pdb, .omf
142        """
143        try:
144            self.datafile = QtWidgets.QFileDialog.getOpenFileName(
145                self, "Choose a file", "", "All Gen files (*.OMF *.omf) ;;"
146                                          "SLD files (*.SLD *.sld);;PDB files (*.pdb *.PDB);; "
147                                          "OMF files (*.OMF *.omf);; "
148                                          "All files (*.*)")[0]
149            if self.datafile:
150                self.default_shape = str(self.cbShape.currentText())
151                self.file_name = os.path.basename(str(self.datafile))
152                self.ext = os.path.splitext(str(self.datafile))[1]
153                if self.ext in self.omf_reader.ext:
154                    loader = self.omf_reader
155                elif self.ext in self.sld_reader.ext:
156                    loader = self.sld_reader
157                elif self.ext in self.pdb_reader.ext:
158                    loader = self.pdb_reader
159                else:
160                    loader = None
161                # disable some entries depending on type of loaded file
162                # (according to documentation)
163                if self.ext.lower() in ['.sld', '.omf', '.pdb']:
164                    self.txtUpFracIn.setEnabled(False)
165                    self.txtUpFracOut.setEnabled(False)
166                    self.txtUpTheta.setEnabled(False)
167
168                if self.reader is not None and self.reader.isrunning():
169                    self.reader.stop()
170                self.cmdLoad.setEnabled(False)
171                self.cmdLoad.setText('Loading...')
172                self.communicator.statusBarUpdateSignal.emit(
173                    "Loading File {}".format(os.path.basename(
174                        str(self.datafile))))
175                self.reader = GenReader(path=str(self.datafile), loader=loader,
176                                        completefn=self.complete_loading_ex,
177                                        updatefn=self.load_update)
178                self.reader.queue()
179        except (RuntimeError, IOError):
180            log_msg = "Generic SAS Calculator: %s" % sys.exc_info()[1]
181            logging.info(log_msg)
182            raise
183        return
184
185    def load_update(self):
186        """ Legacy function used in GenRead """
187        if self.reader.isrunning():
188            status_type = "progress"
189        else:
190            status_type = "stop"
191        logging.info(status_type)
192
193    def complete_loading_ex(self, data=None):
194        """
195        Send the finish message from calculate threads to main thread
196        """
197        self.loadingFinishedSignal.emit(data)
198
199    def complete_loading(self, data=None):
200        """ Function used in GenRead"""
201        assert isinstance(data, list)
202        assert len(data)==1
203        data = data[0]
204        self.cbShape.setEnabled(False)
205        try:
206            is_pdbdata = False
207            self.txtData.setText(os.path.basename(str(self.datafile)))
208            self.is_avg = False
209            if self.ext in self.omf_reader.ext:
210                gen = sas_gen.OMF2SLD()
211                gen.set_data(data)
212                self.sld_data = gen.get_magsld()
213                self.check_units()
214            elif self.ext in self.sld_reader.ext:
215                self.sld_data = data
216            elif self.ext in self.pdb_reader.ext:
217                self.sld_data = data
218                is_pdbdata = True
219            # Display combobox of orientation only for pdb data
220            self.cbOptionsCalc.setVisible(is_pdbdata)
221            self.update_gui()
222        except IOError:
223            log_msg = "Loading Error: " \
224                      "This file format is not supported for GenSAS."
225            logging.info(log_msg)
226            raise
227        except ValueError:
228            log_msg = "Could not find any data"
229            logging.info(log_msg)
230            raise
231        logging.info("Load Complete")
232        self.cmdLoad.setEnabled(True)
233        self.cmdLoad.setText('Load')
234        self.trigger_plot_3d.emit()
235
236    def check_units(self):
237        """
238        Check if the units from the OMF file correspond to the default ones
239        displayed on the interface.
240        If not, modify the GUI with the correct unit
241        """
242        #  TODO: adopt the convention of font and symbol for the updated values
243        if sas_gen.OMFData().valueunit != 'A^(-2)':
244            value_unit = sas_gen.OMFData().valueunit
245            self.lbl_unitMx.setText(value_unit)
246            self.lbl_unitMy.setText(value_unit)
247            self.lbl_unitMz.setText(value_unit)
248            self.lbl_unitNucl.setText(value_unit)
249        if sas_gen.OMFData().meshunit != 'A':
250            mesh_unit = sas_gen.OMFData().meshunit
251            self.lbl_unitx.setText(mesh_unit)
252            self.lbl_unity.setText(mesh_unit)
253            self.lbl_unitz.setText(mesh_unit)
254            self.lbl_unitVolume.setText(mesh_unit+"^3")
255
256    def check_value(self):
257        """Check range of text edits for QMax and Number of Qbins """
258        text_edit = self.sender()
259        text_edit.setStyleSheet('background-color: rgb(255, 255, 255);')
260        if text_edit.text():
261            value = float(str(text_edit.text()))
262            if text_edit == self.txtQxMax:
263                if value <= 0 or value > 1000:
264                    text_edit.setStyleSheet('background-color: rgb(255, 182, 193);')
265                else:
266                    text_edit.setStyleSheet('background-color: rgb(255, 255, 255);')
267            elif text_edit == self.txtNoQBins:
268                if value < 2 or value > 1000:
269                    self.txtNoQBins.setStyleSheet('background-color: rgb(255, 182, 193);')
270                else:
271                    self.txtNoQBins.setStyleSheet('background-color: rgb(255, 255, 255);')
272
273    def update_gui(self):
274        """ Update the interface with values from loaded data """
275        self.model.set_is_avg(self.is_avg)
276        self.model.set_sld_data(self.sld_data)
277        self.model.params['total_volume'] = len(self.sld_data.sld_n)*self.sld_data.vol_pix[0]
278
279        # add condition for activation of save button
280        self.cmdSave.setEnabled(True)
281
282        # activation of 3D plots' buttons (with and without arrows)
283        self.cmdDraw.setEnabled(self.sld_data is not None)
284        self.cmdDrawpoints.setEnabled(self.sld_data is not None)
285
286        self.txtScale.setText(str(self.model.params['scale']))
287        self.txtBackground.setText(str(self.model.params['background']))
288        self.txtSolventSLD.setText(str(self.model.params['solvent_SLD']))
289
290        # Volume to write to interface: npts x volume of first pixel
291        self.txtTotalVolume.setText(str(len(self.sld_data.sld_n)*self.sld_data.vol_pix[0]))
292
293        self.txtUpFracIn.setText(str(self.model.params['Up_frac_in']))
294        self.txtUpFracOut.setText(str(self.model.params['Up_frac_out']))
295        self.txtUpTheta.setText(str(self.model.params['Up_theta']))
296
297        self.txtNoPixels.setText(str(len(self.sld_data.sld_n)))
298        self.txtNoPixels.setEnabled(False)
299
300        list_parameters = ['sld_mx', 'sld_my', 'sld_mz', 'sld_n', 'xnodes',
301                           'ynodes', 'znodes', 'xstepsize', 'ystepsize',
302                           'zstepsize']
303        list_gui_button = [self.txtMx, self.txtMy, self.txtMz, self.txtNucl,
304                           self.txtXnodes, self.txtYnodes, self.txtZnodes,
305                           self.txtXstepsize, self.txtYstepsize,
306                           self.txtZstepsize]
307
308        # Fill right hand side of GUI
309        for indx, item in enumerate(list_parameters):
310            if getattr(self.sld_data, item) is None:
311                list_gui_button[indx].setText('NaN')
312            else:
313                value = getattr(self.sld_data, item)
314                if isinstance(value, numpy.ndarray):
315                    item_for_gui = str(GuiUtils.formatNumber(numpy.average(value), True))
316                else:
317                    item_for_gui = str(GuiUtils.formatNumber(value, True))
318                list_gui_button[indx].setText(item_for_gui)
319
320        # Enable / disable editing of right hand side of GUI
321        for indx, item in enumerate(list_parameters):
322            if indx < 4:
323                # this condition only applies to Mx,y,z and Nucl
324                value = getattr(self.sld_data, item)
325                enable = self.sld_data.pix_type == 'pixel' \
326                         and numpy.min(value) == numpy.max(value)
327            else:
328                enable = not self.sld_data.is_data
329            list_gui_button[indx].setEnabled(enable)
330
331    def write_new_values_from_gui(self):
332        """
333        update parameters using modified inputs from GUI
334        used before computing
335        """
336        if self.txtScale.isModified():
337            self.model.params['scale'] = float(self.txtScale.text())
338
339        if self.txtBackground.isModified():
340            self.model.params['background'] = float(self.txtBackground.text())
341
342        if self.txtSolventSLD.isModified():
343            self.model.params['solvent_SLD'] = float(self.txtSolventSLD.text())
344
345        # Different condition for total volume to get correct volume after
346        # applying set_sld_data in compute
347        if self.txtTotalVolume.isModified() \
348                or self.model.params['total_volume'] != float(self.txtTotalVolume.text()):
349            self.model.params['total_volume'] = float(self.txtTotalVolume.text())
350
351        if self.txtUpFracIn.isModified():
352            self.model.params['Up_frac_in'] = float(self.txtUpFracIn.text())
353
354        if self.txtUpFracOut.isModified():
355            self.model.params['Up_frac_out'] = float(self.txtUpFracOut.text())
356
357        if self.txtUpTheta.isModified():
358            self.model.params['Up_theta'] = float(self.txtUpTheta.text())
359
360        if self.txtMx.isModified():
361            self.sld_data.sld_mx = float(self.txtMx.text())*\
362                                   numpy.ones(len(self.sld_data.sld_mx))
363
364        if self.txtMy.isModified():
365            self.sld_data.sld_my = float(self.txtMy.text())*\
366                                   numpy.ones(len(self.sld_data.sld_my))
367
368        if self.txtMz.isModified():
369            self.sld_data.sld_mz = float(self.txtMz.text())*\
370                                   numpy.ones(len(self.sld_data.sld_mz))
371
372        if self.txtNucl.isModified():
373            self.sld_data.sld_n = float(self.txtNucl.text())*\
374                                  numpy.ones(len(self.sld_data.sld_n))
375
376        if self.txtXnodes.isModified():
377            self.sld_data.xnodes = int(self.txtXnodes.text())
378
379        if self.txtYnodes.isModified():
380            self.sld_data.ynodes = int(self.txtYnodes.text())
381
382        if self.txtZnodes.isModified():
383            self.sld_data.znodes = int(self.txtZnodes.text())
384
385        if self.txtXstepsize.isModified():
386            self.sld_data.xstepsize = float(self.txtXstepsize.text())
387
388        if self.txtYstepsize.isModified():
389            self.sld_data.ystepsize = float(self.txtYstepsize.text())
390
391        if self.txtZstepsize.isModified():
392            self.sld_data.zstepsize = float(self.txtZstepsize.text())
393
394        if self.cbOptionsCalc.isVisible():
395            self.is_avg = (self.cbOptionsCalc.currentIndex() == 1)
396
397    def onHelp(self):
398        """
399        Bring up the Generic Scattering calculator Documentation whenever
400        the HELP button is clicked.
401        Calls Documentation Window with the path of the location within the
402        documentation tree (after /doc/ ....".
403        """
404        location = "/user/qtgui/Calculators/sas_calculator_help.html"
405        self.manager.showHelp(location)
406
407    def onReset(self):
408        """ Reset the inputs of textEdit to default values """
409        try:
410            # reset values in textedits
411            self.txtUpFracIn.setText("1.0")
412            self.txtUpFracOut.setText("1.0")
413            self.txtUpTheta.setText("0.0")
414            self.txtBackground.setText("0.0")
415            self.txtScale.setText("1.0")
416            self.txtSolventSLD.setText("0.0")
417            self.txtTotalVolume.setText("216000.0")
418            self.txtNoQBins.setText("50")
419            self.txtQxMax.setText("0.3")
420            self.txtNoPixels.setText("1000")
421            self.txtMx.setText("0")
422            self.txtMy.setText("0")
423            self.txtMz.setText("0")
424            self.txtNucl.setText("6.97e-06")
425            self.txtXnodes.setText("10")
426            self.txtYnodes.setText("10")
427            self.txtZnodes.setText("10")
428            self.txtXstepsize.setText("6")
429            self.txtYstepsize.setText("6")
430            self.txtZstepsize.setText("6")
431            # reset Load button and textedit
432            self.txtData.setText('Default SLD Profile')
433            self.cmdLoad.setEnabled(True)
434            self.cmdLoad.setText('Load')
435            # reset option for calculation
436            self.cbOptionsCalc.setCurrentIndex(0)
437            self.cbOptionsCalc.setVisible(False)
438            # reset shape button
439            self.cbShape.setCurrentIndex(0)
440            self.cbShape.setEnabled(True)
441            # reset compute button
442            self.cmdCompute.setText('Compute')
443            self.cmdCompute.setEnabled(True)
444            # TODO reload default data set
445            self._create_default_sld_data()
446
447        finally:
448            pass
449
450    def _create_default_2d_data(self):
451        """
452        Copied from previous version
453        Create 2D data by default
454        :warning: This data is never plotted.
455        """
456        self.qmax_x = float(self.txtQxMax.text())
457        self.npts_x = int(self.txtNoQBins.text())
458        self.data = Data2D()
459        self.data.is_data = False
460        # # Default values
461        self.data.detector.append(Detector())
462        index = len(self.data.detector) - 1
463        self.data.detector[index].distance = 8000  # mm
464        self.data.source.wavelength = 6  # A
465        self.data.detector[index].pixel_size.x = 5  # mm
466        self.data.detector[index].pixel_size.y = 5  # mm
467        self.data.detector[index].beam_center.x = self.qmax_x
468        self.data.detector[index].beam_center.y = self.qmax_x
469        xmax = self.qmax_x
470        xmin = -xmax
471        ymax = self.qmax_x
472        ymin = -ymax
473        qstep = self.npts_x
474
475        x = numpy.linspace(start=xmin, stop=xmax, num=qstep, endpoint=True)
476        y = numpy.linspace(start=ymin, stop=ymax, num=qstep, endpoint=True)
477        # use data info instead
478        new_x = numpy.tile(x, (len(y), 1))
479        new_y = numpy.tile(y, (len(x), 1))
480        new_y = new_y.swapaxes(0, 1)
481        # all data require now in 1d array
482        qx_data = new_x.flatten()
483        qy_data = new_y.flatten()
484        q_data = numpy.sqrt(qx_data * qx_data + qy_data * qy_data)
485        # set all True (standing for unmasked) as default
486        mask = numpy.ones(len(qx_data), dtype=bool)
487        self.data.source = Source()
488        self.data.data = numpy.ones(len(mask))
489        self.data.err_data = numpy.ones(len(mask))
490        self.data.qx_data = qx_data
491        self.data.qy_data = qy_data
492        self.data.q_data = q_data
493        self.data.mask = mask
494        # store x and y bin centers in q space
495        self.data.x_bins = x
496        self.data.y_bins = y
497        # max and min taking account of the bin sizes
498        self.data.xmin = xmin
499        self.data.xmax = xmax
500        self.data.ymin = ymin
501        self.data.ymax = ymax
502
503    def _create_default_sld_data(self):
504        """
505        Copied from previous version
506        Making default sld-data
507        """
508        sld_n_default = 6.97e-06  # what is this number??
509        omfdata = sas_gen.OMFData()
510        omf2sld = sas_gen.OMF2SLD()
511        omf2sld.set_data(omfdata, self.default_shape)
512        self.sld_data = omf2sld.output
513        self.sld_data.is_data = False
514        self.sld_data.filename = "Default SLD Profile"
515        self.sld_data.set_sldn(sld_n_default)
516
517    def _create_default_1d_data(self):
518        """
519        Copied from previous version
520        Create 1D data by default
521        :warning: This data is never plotted.
522                    residuals.x = data_copy.x[index]
523            residuals.dy = numpy.ones(len(residuals.y))
524            residuals.dx = None
525            residuals.dxl = None
526            residuals.dxw = None
527        """
528        self.qmax_x = float(self.txtQxMax.text())
529        self.npts_x = int(self.txtNoQBins.text())
530        #  Default values
531        xmax = self.qmax_x
532        xmin = self.qmax_x * _Q1D_MIN
533        qstep = self.npts_x
534        x = numpy.linspace(start=xmin, stop=xmax, num=qstep, endpoint=True)
535        # store x and y bin centers in q space
536        y = numpy.ones(len(x))
537        dy = numpy.zeros(len(x))
538        dx = numpy.zeros(len(x))
539        self.data = Data1D(x=x, y=y)
540        self.data.dx = dx
541        self.data.dy = dy
542
543    def onCompute(self):
544        """
545        Copied from previous version
546        Execute the computation of I(qx, qy)
547        """
548        # Set default data when nothing loaded yet
549        if self.sld_data is None:
550            self._create_default_sld_data()
551        try:
552            self.model.set_sld_data(self.sld_data)
553            self.write_new_values_from_gui()
554            if self.is_avg or self.is_avg is None:
555                self._create_default_1d_data()
556                i_out = numpy.zeros(len(self.data.y))
557                inputs = [self.data.x, [], i_out]
558            else:
559                self._create_default_2d_data()
560                i_out = numpy.zeros(len(self.data.data))
561                inputs = [self.data.qx_data, self.data.qy_data, i_out]
562            logging.info("Computation is in progress...")
563            self.cmdCompute.setText('Wait...')
564            self.cmdCompute.setEnabled(False)
565            d = threads.deferToThread(self.complete, inputs, self._update)
566            # Add deferred callback for call return
567            #d.addCallback(self.plot_1_2d)
568            d.addCallback(self.calculateComplete)
569            d.addErrback(self.calculateFailed)
570        except:
571            log_msg = "{}. stop".format(sys.exc_info()[1])
572            logging.info(log_msg)
573        return
574
575    def _update(self, value):
576        """
577        Copied from previous version
578        """
579        pass
580
581    def calculateFailed(self, reason):
582        """
583        """
584        print("Calculate Failed with:\n", reason)
585        pass
586
587    def calculateComplete(self, d):
588        """
589        Notify the main thread
590        """
591        self.calculationFinishedSignal.emit()
592
593    def complete(self, input, update=None):
594        """
595        Gen compute complete function
596        :Param input: input list [qx_data, qy_data, i_out]
597        """
598        out = numpy.empty(0)
599        for ind in range(len(input[0])):
600            if self.is_avg:
601                if ind % 1 == 0 and update is not None:
602                    # update()
603                    percentage = int(100.0 * float(ind) / len(input[0]))
604                    update(percentage)
605                    time.sleep(0.001)  # 0.1
606                inputi = [input[0][ind:ind + 1], [], input[2][ind:ind + 1]]
607                outi = self.model.run(inputi)
608                out = numpy.append(out, outi)
609            else:
610                if ind % 50 == 0 and update is not None:
611                    percentage = int(100.0 * float(ind) / len(input[0]))
612                    update(percentage)
613                    time.sleep(0.001)
614                inputi = [input[0][ind:ind + 1], input[1][ind:ind + 1],
615                          input[2][ind:ind + 1]]
616                outi = self.model.runXY(inputi)
617                out = numpy.append(out, outi)
618        self.data_to_plot = out
619        logging.info('Gen computation completed.')
620        self.cmdCompute.setText('Compute')
621        self.cmdCompute.setEnabled(True)
622        return
623
624    def onSaveFile(self):
625        """Save data as .sld file"""
626        path = os.path.dirname(str(self.datafile))
627        default_name = os.path.join(path, 'sld_file')
628        kwargs = {
629            'parent': self,
630            'directory': default_name,
631            'filter': 'SLD file (*.sld)',
632            'options': QtWidgets.QFileDialog.DontUseNativeDialog}
633        # Query user for filename.
634        filename_tuple = QtWidgets.QFileDialog.getSaveFileName(**kwargs)
635        filename = filename_tuple[0]
636        if filename:
637            try:
638                _, extension = os.path.splitext(filename)
639                if not extension:
640                    filename = '.'.join((filename, 'sld'))
641                sas_gen.SLDReader().write(filename, self.sld_data)
642            except:
643                raise
644
645    def plot3d(self, has_arrow=False):
646        """ Generate 3D plot in real space with or without arrows """
647        self.write_new_values_from_gui()
648        graph_title = " Graph {}: {} 3D SLD Profile".format(self.graph_num,
649                                                            self.file_name)
650        if has_arrow:
651            graph_title += ' - Magnetic Vector as Arrow'
652
653        plot3D = Plotter3D(self, graph_title)
654        plot3D.plot(self.sld_data, has_arrow=has_arrow)
655        plot3D.show()
656        self.graph_num += 1
657
658    def plot_1_2d(self):
659        """ Generate 1D or 2D plot, called in Compute"""
660        if self.is_avg or self.is_avg is None:
661            data = Data1D(x=self.data.x, y=self.data_to_plot)
662            data.title = "GenSAS {}  #{} 1D".format(self.file_name,
663                                                    int(self.graph_num))
664            data.xaxis('\\rm{Q_{x}}', '\AA^{-1}')
665            data.yaxis('\\rm{Intensity}', 'cm^{-1}')
666
667            self.graph_num += 1
668        else:
669            numpy.nan_to_num(self.data_to_plot)
670            data = Data2D(image=self.data_to_plot,
671                          qx_data=self.data.qx_data,
672                          qy_data=self.data.qy_data,
673                          q_data=self.data.q_data,
674                          xmin=self.data.xmin, xmax=self.data.ymax,
675                          ymin=self.data.ymin, ymax=self.data.ymax,
676                          err_image=self.data.err_data)
677            data.title = "GenSAS {}  #{} 2D".format(self.file_name,
678                                                    int(self.graph_num))
679            zeros = numpy.ones(data.data.size, dtype=bool)
680            data.mask = zeros
681
682            self.graph_num += 1
683            # TODO
684        new_item = GuiUtils.createModelItemWithPlot(data, name=data.title)
685        self.communicator.updateModelFromPerspectiveSignal.emit(new_item)
686        self.communicator.forcePlotDisplaySignal.emit([new_item, data])
687
688class Plotter3DWidget(PlotterBase):
689    """
690    3D Plot widget for use with a QDialog
691    """
692    def __init__(self, parent=None, manager=None):
693        super(Plotter3DWidget, self).__init__(parent,  manager=manager)
694
695    @property
696    def data(self):
697        return self._data
698
699    @data.setter
700    def data(self, data=None):
701        """ data setter """
702        self._data = data
703
704    def plot(self, data=None, has_arrow=False):
705        """
706        Plot 3D self._data
707        """
708        if not data:
709            return
710        self.data = data
711        #assert(self._data)
712        # Prepare and show the plot
713        self.showPlot(data=self.data, has_arrow=has_arrow)
714
715    def showPlot(self, data, has_arrow=False):
716        """
717        Render and show the current data
718        """
719        # If we don't have any data, skip.
720        if data is None:
721            return
722        # This import takes forever - place it here so the main UI starts faster
723        from mpl_toolkits.mplot3d import Axes3D
724        color_dic = {'H': 'blue', 'D': 'purple', 'N': 'orange',
725                     'O': 'red', 'C': 'green', 'P': 'cyan', 'Other': 'k'}
726        marker = ','
727        m_size = 2
728
729        pos_x = data.pos_x
730        pos_y = data.pos_y
731        pos_z = data.pos_z
732        sld_mx = data.sld_mx
733        sld_my = data.sld_my
734        sld_mz = data.sld_mz
735        pix_symbol = data.pix_symbol
736        sld_tot = numpy.fabs(sld_mx) + numpy.fabs(sld_my) + \
737                  numpy.fabs(sld_mz) + numpy.fabs(data.sld_n)
738        is_nonzero = sld_tot > 0.0
739        is_zero = sld_tot == 0.0
740
741        if data.pix_type == 'atom':
742            marker = 'o'
743            m_size = 3.5
744
745        self.figure.clear()
746        self.figure.subplots_adjust(left=0.1, right=.8, bottom=.1)
747        ax = Axes3D(self.figure)
748        ax.set_xlabel('x ($\A{}$)'.format(data.pos_unit))
749        ax.set_ylabel('z ($\A{}$)'.format(data.pos_unit))
750        ax.set_zlabel('y ($\A{}$)'.format(data.pos_unit))
751
752        # I. Plot null points
753        if is_zero.any():
754            im = ax.plot(pos_x[is_zero], pos_z[is_zero], pos_y[is_zero],
755                           marker, c="y", alpha=0.5, markeredgecolor='y',
756                           markersize=m_size)
757            pos_x = pos_x[is_nonzero]
758            pos_y = pos_y[is_nonzero]
759            pos_z = pos_z[is_nonzero]
760            sld_mx = sld_mx[is_nonzero]
761            sld_my = sld_my[is_nonzero]
762            sld_mz = sld_mz[is_nonzero]
763            pix_symbol = data.pix_symbol[is_nonzero]
764        # II. Plot selective points in color
765        other_color = numpy.ones(len(pix_symbol), dtype='bool')
766        for key in list(color_dic.keys()):
767            chosen_color = pix_symbol == key
768            if numpy.any(chosen_color):
769                other_color = other_color & (chosen_color!=True)
770                color = color_dic[key]
771                im = ax.plot(pos_x[chosen_color], pos_z[chosen_color],
772                         pos_y[chosen_color], marker, c=color, alpha=0.5,
773                         markeredgecolor=color, markersize=m_size, label=key)
774        # III. Plot All others
775        if numpy.any(other_color):
776            a_name = ''
777            if data.pix_type == 'atom':
778                # Get atom names not in the list
779                a_names = [symb for symb in pix_symbol \
780                           if symb not in list(color_dic.keys())]
781                a_name = a_names[0]
782                for name in a_names:
783                    new_name = ", " + name
784                    if a_name.count(name) == 0:
785                        a_name += new_name
786            # plot in black
787            im = ax.plot(pos_x[other_color], pos_z[other_color],
788                         pos_y[other_color], marker, c="k", alpha=0.5,
789                         markeredgecolor="k", markersize=m_size, label=a_name)
790        if data.pix_type == 'atom':
791            ax.legend(loc='upper left', prop={'size': 10})
792        # IV. Draws atomic bond with grey lines if any
793        if data.has_conect:
794            for ind in range(len(data.line_x)):
795                im = ax.plot(data.line_x[ind], data.line_z[ind],
796                             data.line_y[ind], '-', lw=0.6, c="grey",
797                             alpha=0.3)
798        # V. Draws magnetic vectors
799        if has_arrow and len(pos_x) > 0:
800            def _draw_arrow(input=None, update=None):
801                # import moved here for performance reasons
802                from sas.qtgui.Plotting.Arrow3D import Arrow3D
803                """
804                draw magnetic vectors w/arrow
805                """
806                max_mx = max(numpy.fabs(sld_mx))
807                max_my = max(numpy.fabs(sld_my))
808                max_mz = max(numpy.fabs(sld_mz))
809                max_m = max(max_mx, max_my, max_mz)
810                try:
811                    max_step = max(data.xstepsize, data.ystepsize, data.zstepsize)
812                except:
813                    max_step = 0
814                if max_step <= 0:
815                    max_step = 5
816                try:
817                    if max_m != 0:
818                        unit_x2 = sld_mx / max_m
819                        unit_y2 = sld_my / max_m
820                        unit_z2 = sld_mz / max_m
821                        # 0.8 is for avoiding the color becomes white=(1,1,1))
822                        color_x = numpy.fabs(unit_x2 * 0.8)
823                        color_y = numpy.fabs(unit_y2 * 0.8)
824                        color_z = numpy.fabs(unit_z2 * 0.8)
825                        x2 = pos_x + unit_x2 * max_step
826                        y2 = pos_y + unit_y2 * max_step
827                        z2 = pos_z + unit_z2 * max_step
828                        x_arrow = numpy.column_stack((pos_x, x2))
829                        y_arrow = numpy.column_stack((pos_y, y2))
830                        z_arrow = numpy.column_stack((pos_z, z2))
831                        colors = numpy.column_stack((color_x, color_y, color_z))
832                        arrows = Arrow3D(self.figure, x_arrow, z_arrow, y_arrow,
833                                        colors, mutation_scale=10, lw=1,
834                                        arrowstyle="->", alpha=0.5)
835                        ax.add_artist(arrows)
836                except:
837                    pass
838                log_msg = "Arrow Drawing completed.\n"
839                logging.info(log_msg)
840            log_msg = "Arrow Drawing is in progress..."
841            logging.info(log_msg)
842
843            # Defer the drawing of arrows to another thread
844            d = threads.deferToThread(_draw_arrow, ax)
845
846        self.figure.canvas.resizing = False
847        self.figure.canvas.draw()
848
849    def createContextMenu(self):
850        """
851        Define common context menu and associated actions for the MPL widget
852        """
853        return
854
855    def createContextMenuQuick(self):
856        """
857        Define context menu and associated actions for the quickplot MPL widget
858        """
859        return
860
861
862class Plotter3D(QtWidgets.QDialog, Plotter3DWidget):
863    def __init__(self, parent=None, graph_title=''):
864        self.graph_title = graph_title
865        QtWidgets.QDialog.__init__(self)
866        Plotter3DWidget.__init__(self, manager=parent)
867        self.setWindowTitle(self.graph_title)
868
Note: See TracBrowser for help on using the repository browser.