source: sasview/src/sas/qtgui/Calculators/GenericScatteringCalculator.py @ ccdee50

ESS_GUIESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_openclESS_GUI_sync_sascalc
Last change on this file since ccdee50 was 8c85ac1, checked in by Piotr Rozyczko <piotr.rozyczko@…>, 6 years ago

Fixed plot generation and handling in the generic scattering calc.
SASVIEW-1216

  • Property mode set to 100644
File size: 34.2 KB
Line 
1import sys
2import os
3import numpy
4import logging
5import time
6
7from PyQt5 import QtCore
8from PyQt5 import QtGui
9from PyQt5 import QtWidgets
10
11from twisted.internet import threads
12
13import sas.qtgui.Utilities.GuiUtils as GuiUtils
14from sas.qtgui.Utilities.GenericReader import GenReader
15from sas.sascalc.dataloader.data_info import Detector
16from sas.sascalc.dataloader.data_info import Source
17from sas.sascalc.calculator import sas_gen
18from sas.qtgui.Plotting.PlotterBase import PlotterBase
19from sas.qtgui.Plotting.Plotter2D import Plotter2D
20from sas.qtgui.Plotting.Plotter import Plotter
21
22from sas.qtgui.Plotting.PlotterData import Data1D
23from sas.qtgui.Plotting.PlotterData import Data2D
24
25# Local UI
26from .UI.GenericScatteringCalculator import Ui_GenericScatteringCalculator
27
28_Q1D_MIN = 0.001
29
30
31class GenericScatteringCalculator(QtWidgets.QDialog, Ui_GenericScatteringCalculator):
32
33    trigger_plot_3d = QtCore.pyqtSignal()
34    calculationFinishedSignal = QtCore.pyqtSignal()
35    loadingFinishedSignal = QtCore.pyqtSignal(list)
36
37    def __init__(self, parent=None):
38        super(GenericScatteringCalculator, self).__init__()
39        self.setupUi(self)
40        self.manager = parent
41        self.communicator = self.manager.communicator()
42        self.model = sas_gen.GenSAS()
43        self.omf_reader = sas_gen.OMFReader()
44        self.sld_reader = sas_gen.SLDReader()
45        self.pdb_reader = sas_gen.PDBReader()
46        self.reader = None
47        self.sld_data = None
48
49        self.parameters = []
50        self.data = None
51        self.datafile = None
52        self.file_name = ''
53        self.ext = None
54        self.default_shape = str(self.cbShape.currentText())
55        self.is_avg = False
56        self.data_to_plot = None
57        self.graph_num = 1  # index for name of graph
58
59        # combox box
60        self.cbOptionsCalc.setVisible(False)
61
62        # push buttons
63        self.cmdClose.clicked.connect(self.accept)
64        self.cmdHelp.clicked.connect(self.onHelp)
65
66        self.cmdLoad.clicked.connect(self.loadFile)
67        self.cmdCompute.clicked.connect(self.onCompute)
68        self.cmdReset.clicked.connect(self.onReset)
69        self.cmdSave.clicked.connect(self.onSaveFile)
70
71        self.cmdDraw.clicked.connect(lambda: self.plot3d(has_arrow=True))
72        self.cmdDrawpoints.clicked.connect(lambda: self.plot3d(has_arrow=False))
73
74        # validators
75        # scale, volume and background must be positive
76        validat_regex_pos = QtCore.QRegExp('^[+]?([.]\d+|\d+([.]\d+)?)$')
77        self.txtScale.setValidator(QtGui.QRegExpValidator(validat_regex_pos,
78                                                          self.txtScale))
79        self.txtBackground.setValidator(QtGui.QRegExpValidator(
80            validat_regex_pos, self.txtBackground))
81        self.txtTotalVolume.setValidator(QtGui.QRegExpValidator(
82            validat_regex_pos, self.txtTotalVolume))
83
84        # fraction of spin up between 0 and 1
85        validat_regexbetween0_1 = QtCore.QRegExp('^(0(\.\d*)*|1(\.0+)?)$')
86        self.txtUpFracIn.setValidator(
87            QtGui.QRegExpValidator(validat_regexbetween0_1, self.txtUpFracIn))
88        self.txtUpFracOut.setValidator(
89            QtGui.QRegExpValidator(validat_regexbetween0_1, self.txtUpFracOut))
90
91        # 0 < Qmax <= 1000
92        validat_regex_q = QtCore.QRegExp('^1000$|^[+]?(\d{1,3}([.]\d+)?)$')
93        self.txtQxMax.setValidator(QtGui.QRegExpValidator(validat_regex_q,
94                                                          self.txtQxMax))
95        self.txtQxMax.textChanged.connect(self.check_value)
96
97        # 2 <= Qbin <= 1000
98        self.txtNoQBins.setValidator(QtGui.QRegExpValidator(validat_regex_q,
99                                                            self.txtNoQBins))
100        self.txtNoQBins.textChanged.connect(self.check_value)
101
102        # plots - 3D in real space
103        self.trigger_plot_3d.connect(lambda: self.plot3d(has_arrow=False))
104
105        # plots - 3D in real space
106        self.calculationFinishedSignal.connect(self.plot_1_2d)
107
108        # notify main thread about file load complete
109        self.loadingFinishedSignal.connect(self.complete_loading)
110
111        # TODO the option Ellipsoid has not been implemented
112        self.cbShape.currentIndexChanged.connect(self.selectedshapechange)
113
114        # New font to display angstrom symbol
115        new_font = 'font-family: -apple-system, "Helvetica Neue", "Ubuntu";'
116        self.lblUnitSolventSLD.setStyleSheet(new_font)
117        self.lblUnitVolume.setStyleSheet(new_font)
118        self.lbl5.setStyleSheet(new_font)
119        self.lblUnitMx.setStyleSheet(new_font)
120        self.lblUnitMy.setStyleSheet(new_font)
121        self.lblUnitMz.setStyleSheet(new_font)
122        self.lblUnitNucl.setStyleSheet(new_font)
123        self.lblUnitx.setStyleSheet(new_font)
124        self.lblUnity.setStyleSheet(new_font)
125        self.lblUnitz.setStyleSheet(new_font)
126
127    def selectedshapechange(self):
128        """
129        TODO Temporary solution to display information about option 'Ellipsoid'
130        """
131        print("The option Ellipsoid has not been implemented yet.")
132        self.communicator.statusBarUpdateSignal.emit(
133            "The option Ellipsoid has not been implemented yet.")
134
135    def loadFile(self):
136        """
137        Open menu to choose the datafile to load
138        Only extensions .SLD, .PDB, .OMF, .sld, .pdb, .omf
139        """
140        try:
141            self.datafile = QtWidgets.QFileDialog.getOpenFileName(
142                self, "Choose a file", "", "All Gen files (*.OMF *.omf) ;;"
143                                          "SLD files (*.SLD *.sld);;PDB files (*.pdb *.PDB);; "
144                                          "OMF files (*.OMF *.omf);; "
145                                          "All files (*.*)")[0]
146            if self.datafile:
147                self.default_shape = str(self.cbShape.currentText())
148                self.file_name = os.path.basename(str(self.datafile))
149                self.ext = os.path.splitext(str(self.datafile))[1]
150                if self.ext in self.omf_reader.ext:
151                    loader = self.omf_reader
152                elif self.ext in self.sld_reader.ext:
153                    loader = self.sld_reader
154                elif self.ext in self.pdb_reader.ext:
155                    loader = self.pdb_reader
156                else:
157                    loader = None
158                # disable some entries depending on type of loaded file
159                # (according to documentation)
160                if self.ext.lower() in ['.sld', '.omf', '.pdb']:
161                    self.txtUpFracIn.setEnabled(False)
162                    self.txtUpFracOut.setEnabled(False)
163                    self.txtUpTheta.setEnabled(False)
164
165                if self.reader is not None and self.reader.isrunning():
166                    self.reader.stop()
167                self.cmdLoad.setEnabled(False)
168                self.cmdLoad.setText('Loading...')
169                self.communicator.statusBarUpdateSignal.emit(
170                    "Loading File {}".format(os.path.basename(
171                        str(self.datafile))))
172                self.reader = GenReader(path=str(self.datafile), loader=loader,
173                                        completefn=self.complete_loading_ex,
174                                        updatefn=self.load_update)
175                self.reader.queue()
176        except (RuntimeError, IOError):
177            log_msg = "Generic SAS Calculator: %s" % sys.exc_info()[1]
178            logging.info(log_msg)
179            raise
180        return
181
182    def load_update(self):
183        """ Legacy function used in GenRead """
184        if self.reader.isrunning():
185            status_type = "progress"
186        else:
187            status_type = "stop"
188        logging.info(status_type)
189
190    def complete_loading_ex(self, data=None):
191        """
192        Send the finish message from calculate threads to main thread
193        """
194        self.loadingFinishedSignal.emit(data)
195
196    def complete_loading(self, data=None):
197        """ Function used in GenRead"""
198        assert isinstance(data, list)
199        assert len(data)==1
200        data = data[0]
201        self.cbShape.setEnabled(False)
202        try:
203            is_pdbdata = False
204            self.txtData.setText(os.path.basename(str(self.datafile)))
205            self.is_avg = False
206            if self.ext in self.omf_reader.ext:
207                gen = sas_gen.OMF2SLD()
208                gen.set_data(data)
209                self.sld_data = gen.get_magsld()
210                self.check_units()
211            elif self.ext in self.sld_reader.ext:
212                self.sld_data = data
213            elif self.ext in self.pdb_reader.ext:
214                self.sld_data = data
215                is_pdbdata = True
216            # Display combobox of orientation only for pdb data
217            self.cbOptionsCalc.setVisible(is_pdbdata)
218            self.update_gui()
219        except IOError:
220            log_msg = "Loading Error: " \
221                      "This file format is not supported for GenSAS."
222            logging.info(log_msg)
223            raise
224        except ValueError:
225            log_msg = "Could not find any data"
226            logging.info(log_msg)
227            raise
228        logging.info("Load Complete")
229        self.cmdLoad.setEnabled(True)
230        self.cmdLoad.setText('Load')
231        self.trigger_plot_3d.emit()
232
233    def check_units(self):
234        """
235        Check if the units from the OMF file correspond to the default ones
236        displayed on the interface.
237        If not, modify the GUI with the correct unit
238        """
239        #  TODO: adopt the convention of font and symbol for the updated values
240        if sas_gen.OMFData().valueunit != 'A^(-2)':
241            value_unit = sas_gen.OMFData().valueunit
242            self.lbl_unitMx.setText(value_unit)
243            self.lbl_unitMy.setText(value_unit)
244            self.lbl_unitMz.setText(value_unit)
245            self.lbl_unitNucl.setText(value_unit)
246        if sas_gen.OMFData().meshunit != 'A':
247            mesh_unit = sas_gen.OMFData().meshunit
248            self.lbl_unitx.setText(mesh_unit)
249            self.lbl_unity.setText(mesh_unit)
250            self.lbl_unitz.setText(mesh_unit)
251            self.lbl_unitVolume.setText(mesh_unit+"^3")
252
253    def check_value(self):
254        """Check range of text edits for QMax and Number of Qbins """
255        text_edit = self.sender()
256        text_edit.setStyleSheet('background-color: rgb(255, 255, 255);')
257        if text_edit.text():
258            value = float(str(text_edit.text()))
259            if text_edit == self.txtQxMax:
260                if value <= 0 or value > 1000:
261                    text_edit.setStyleSheet('background-color: rgb(255, 182, 193);')
262                else:
263                    text_edit.setStyleSheet('background-color: rgb(255, 255, 255);')
264            elif text_edit == self.txtNoQBins:
265                if value < 2 or value > 1000:
266                    self.txtNoQBins.setStyleSheet('background-color: rgb(255, 182, 193);')
267                else:
268                    self.txtNoQBins.setStyleSheet('background-color: rgb(255, 255, 255);')
269
270    def update_gui(self):
271        """ Update the interface with values from loaded data """
272        self.model.set_is_avg(self.is_avg)
273        self.model.set_sld_data(self.sld_data)
274        self.model.params['total_volume'] = len(self.sld_data.sld_n)*self.sld_data.vol_pix[0]
275
276        # add condition for activation of save button
277        self.cmdSave.setEnabled(True)
278
279        # activation of 3D plots' buttons (with and without arrows)
280        self.cmdDraw.setEnabled(self.sld_data is not None)
281        self.cmdDrawpoints.setEnabled(self.sld_data is not None)
282
283        self.txtScale.setText(str(self.model.params['scale']))
284        self.txtBackground.setText(str(self.model.params['background']))
285        self.txtSolventSLD.setText(str(self.model.params['solvent_SLD']))
286
287        # Volume to write to interface: npts x volume of first pixel
288        self.txtTotalVolume.setText(str(len(self.sld_data.sld_n)*self.sld_data.vol_pix[0]))
289
290        self.txtUpFracIn.setText(str(self.model.params['Up_frac_in']))
291        self.txtUpFracOut.setText(str(self.model.params['Up_frac_out']))
292        self.txtUpTheta.setText(str(self.model.params['Up_theta']))
293
294        self.txtNoPixels.setText(str(len(self.sld_data.sld_n)))
295        self.txtNoPixels.setEnabled(False)
296
297        list_parameters = ['sld_mx', 'sld_my', 'sld_mz', 'sld_n', 'xnodes',
298                           'ynodes', 'znodes', 'xstepsize', 'ystepsize',
299                           'zstepsize']
300        list_gui_button = [self.txtMx, self.txtMy, self.txtMz, self.txtNucl,
301                           self.txtXnodes, self.txtYnodes, self.txtZnodes,
302                           self.txtXstepsize, self.txtYstepsize,
303                           self.txtZstepsize]
304
305        # Fill right hand side of GUI
306        for indx, item in enumerate(list_parameters):
307            if getattr(self.sld_data, item) is None:
308                list_gui_button[indx].setText('NaN')
309            else:
310                value = getattr(self.sld_data, item)
311                if isinstance(value, numpy.ndarray):
312                    item_for_gui = str(GuiUtils.formatNumber(numpy.average(value), True))
313                else:
314                    item_for_gui = str(GuiUtils.formatNumber(value, True))
315                list_gui_button[indx].setText(item_for_gui)
316
317        # Enable / disable editing of right hand side of GUI
318        for indx, item in enumerate(list_parameters):
319            if indx < 4:
320                # this condition only applies to Mx,y,z and Nucl
321                value = getattr(self.sld_data, item)
322                enable = self.sld_data.pix_type == 'pixel' \
323                         and numpy.min(value) == numpy.max(value)
324            else:
325                enable = not self.sld_data.is_data
326            list_gui_button[indx].setEnabled(enable)
327
328    def write_new_values_from_gui(self):
329        """
330        update parameters using modified inputs from GUI
331        used before computing
332        """
333        if self.txtScale.isModified():
334            self.model.params['scale'] = float(self.txtScale.text())
335
336        if self.txtBackground.isModified():
337            self.model.params['background'] = float(self.txtBackground.text())
338
339        if self.txtSolventSLD.isModified():
340            self.model.params['solvent_SLD'] = float(self.txtSolventSLD.text())
341
342        # Different condition for total volume to get correct volume after
343        # applying set_sld_data in compute
344        if self.txtTotalVolume.isModified() \
345                or self.model.params['total_volume'] != float(self.txtTotalVolume.text()):
346            self.model.params['total_volume'] = float(self.txtTotalVolume.text())
347
348        if self.txtUpFracIn.isModified():
349            self.model.params['Up_frac_in'] = float(self.txtUpFracIn.text())
350
351        if self.txtUpFracOut.isModified():
352            self.model.params['Up_frac_out'] = float(self.txtUpFracOut.text())
353
354        if self.txtUpTheta.isModified():
355            self.model.params['Up_theta'] = float(self.txtUpTheta.text())
356
357        if self.txtMx.isModified():
358            self.sld_data.sld_mx = float(self.txtMx.text())*\
359                                   numpy.ones(len(self.sld_data.sld_mx))
360
361        if self.txtMy.isModified():
362            self.sld_data.sld_my = float(self.txtMy.text())*\
363                                   numpy.ones(len(self.sld_data.sld_my))
364
365        if self.txtMz.isModified():
366            self.sld_data.sld_mz = float(self.txtMz.text())*\
367                                   numpy.ones(len(self.sld_data.sld_mz))
368
369        if self.txtNucl.isModified():
370            self.sld_data.sld_n = float(self.txtNucl.text())*\
371                                  numpy.ones(len(self.sld_data.sld_n))
372
373        if self.txtXnodes.isModified():
374            self.sld_data.xnodes = int(self.txtXnodes.text())
375
376        if self.txtYnodes.isModified():
377            self.sld_data.ynodes = int(self.txtYnodes.text())
378
379        if self.txtZnodes.isModified():
380            self.sld_data.znodes = int(self.txtZnodes.text())
381
382        if self.txtXstepsize.isModified():
383            self.sld_data.xstepsize = float(self.txtXstepsize.text())
384
385        if self.txtYstepsize.isModified():
386            self.sld_data.ystepsize = float(self.txtYstepsize.text())
387
388        if self.txtZstepsize.isModified():
389            self.sld_data.zstepsize = float(self.txtZstepsize.text())
390
391        if self.cbOptionsCalc.isVisible():
392            self.is_avg = (self.cbOptionsCalc.currentIndex() == 1)
393
394    def onHelp(self):
395        """
396        Bring up the Generic Scattering calculator Documentation whenever
397        the HELP button is clicked.
398        Calls Documentation Window with the path of the location within the
399        documentation tree (after /doc/ ....".
400        """
401        location = "/user/qtgui/Calculators/sas_calculator_help.html"
402        self.manager.showHelp(location)
403
404    def onReset(self):
405        """ Reset the inputs of textEdit to default values """
406        try:
407            # reset values in textedits
408            self.txtUpFracIn.setText("1.0")
409            self.txtUpFracOut.setText("1.0")
410            self.txtUpTheta.setText("0.0")
411            self.txtBackground.setText("0.0")
412            self.txtScale.setText("1.0")
413            self.txtSolventSLD.setText("0.0")
414            self.txtTotalVolume.setText("216000.0")
415            self.txtNoQBins.setText("50")
416            self.txtQxMax.setText("0.3")
417            self.txtNoPixels.setText("1000")
418            self.txtMx.setText("0")
419            self.txtMy.setText("0")
420            self.txtMz.setText("0")
421            self.txtNucl.setText("6.97e-06")
422            self.txtXnodes.setText("10")
423            self.txtYnodes.setText("10")
424            self.txtZnodes.setText("10")
425            self.txtXstepsize.setText("6")
426            self.txtYstepsize.setText("6")
427            self.txtZstepsize.setText("6")
428            # reset Load button and textedit
429            self.txtData.setText('Default SLD Profile')
430            self.cmdLoad.setEnabled(True)
431            self.cmdLoad.setText('Load')
432            # reset option for calculation
433            self.cbOptionsCalc.setCurrentIndex(0)
434            self.cbOptionsCalc.setVisible(False)
435            # reset shape button
436            self.cbShape.setCurrentIndex(0)
437            self.cbShape.setEnabled(True)
438            # reset compute button
439            self.cmdCompute.setText('Compute')
440            self.cmdCompute.setEnabled(True)
441            # TODO reload default data set
442            self._create_default_sld_data()
443
444        finally:
445            pass
446
447    def _create_default_2d_data(self):
448        """
449        Copied from previous version
450        Create 2D data by default
451        :warning: This data is never plotted.
452        """
453        self.qmax_x = float(self.txtQxMax.text())
454        self.npts_x = int(self.txtNoQBins.text())
455        self.data = Data2D()
456        self.data.is_data = False
457        # # Default values
458        self.data.detector.append(Detector())
459        index = len(self.data.detector) - 1
460        self.data.detector[index].distance = 8000  # mm
461        self.data.source.wavelength = 6  # A
462        self.data.detector[index].pixel_size.x = 5  # mm
463        self.data.detector[index].pixel_size.y = 5  # mm
464        self.data.detector[index].beam_center.x = self.qmax_x
465        self.data.detector[index].beam_center.y = self.qmax_x
466        xmax = self.qmax_x
467        xmin = -xmax
468        ymax = self.qmax_x
469        ymin = -ymax
470        qstep = self.npts_x
471
472        x = numpy.linspace(start=xmin, stop=xmax, num=qstep, endpoint=True)
473        y = numpy.linspace(start=ymin, stop=ymax, num=qstep, endpoint=True)
474        # use data info instead
475        new_x = numpy.tile(x, (len(y), 1))
476        new_y = numpy.tile(y, (len(x), 1))
477        new_y = new_y.swapaxes(0, 1)
478        # all data require now in 1d array
479        qx_data = new_x.flatten()
480        qy_data = new_y.flatten()
481        q_data = numpy.sqrt(qx_data * qx_data + qy_data * qy_data)
482        # set all True (standing for unmasked) as default
483        mask = numpy.ones(len(qx_data), dtype=bool)
484        self.data.source = Source()
485        self.data.data = numpy.ones(len(mask))
486        self.data.err_data = numpy.ones(len(mask))
487        self.data.qx_data = qx_data
488        self.data.qy_data = qy_data
489        self.data.q_data = q_data
490        self.data.mask = mask
491        # store x and y bin centers in q space
492        self.data.x_bins = x
493        self.data.y_bins = y
494        # max and min taking account of the bin sizes
495        self.data.xmin = xmin
496        self.data.xmax = xmax
497        self.data.ymin = ymin
498        self.data.ymax = ymax
499
500    def _create_default_sld_data(self):
501        """
502        Copied from previous version
503        Making default sld-data
504        """
505        sld_n_default = 6.97e-06  # what is this number??
506        omfdata = sas_gen.OMFData()
507        omf2sld = sas_gen.OMF2SLD()
508        omf2sld.set_data(omfdata, self.default_shape)
509        self.sld_data = omf2sld.output
510        self.sld_data.is_data = False
511        self.sld_data.filename = "Default SLD Profile"
512        self.sld_data.set_sldn(sld_n_default)
513
514    def _create_default_1d_data(self):
515        """
516        Copied from previous version
517        Create 1D data by default
518        :warning: This data is never plotted.
519                    residuals.x = data_copy.x[index]
520            residuals.dy = numpy.ones(len(residuals.y))
521            residuals.dx = None
522            residuals.dxl = None
523            residuals.dxw = None
524        """
525        self.qmax_x = float(self.txtQxMax.text())
526        self.npts_x = int(self.txtNoQBins.text())
527        #  Default values
528        xmax = self.qmax_x
529        xmin = self.qmax_x * _Q1D_MIN
530        qstep = self.npts_x
531        x = numpy.linspace(start=xmin, stop=xmax, num=qstep, endpoint=True)
532        # store x and y bin centers in q space
533        y = numpy.ones(len(x))
534        dy = numpy.zeros(len(x))
535        dx = numpy.zeros(len(x))
536        self.data = Data1D(x=x, y=y)
537        self.data.dx = dx
538        self.data.dy = dy
539
540    def onCompute(self):
541        """
542        Copied from previous version
543        Execute the computation of I(qx, qy)
544        """
545        # Set default data when nothing loaded yet
546        if self.sld_data is None:
547            self._create_default_sld_data()
548        try:
549            self.model.set_sld_data(self.sld_data)
550            self.write_new_values_from_gui()
551            if self.is_avg or self.is_avg is None:
552                self._create_default_1d_data()
553                i_out = numpy.zeros(len(self.data.y))
554                inputs = [self.data.x, [], i_out]
555            else:
556                self._create_default_2d_data()
557                i_out = numpy.zeros(len(self.data.data))
558                inputs = [self.data.qx_data, self.data.qy_data, i_out]
559            logging.info("Computation is in progress...")
560            self.cmdCompute.setText('Wait...')
561            self.cmdCompute.setEnabled(False)
562            d = threads.deferToThread(self.complete, inputs, self._update)
563            # Add deferred callback for call return
564            #d.addCallback(self.plot_1_2d)
565            d.addCallback(self.calculateComplete)
566            d.addErrback(self.calculateFailed)
567        except:
568            log_msg = "{}. stop".format(sys.exc_info()[1])
569            logging.info(log_msg)
570        return
571
572    def _update(self, value):
573        """
574        Copied from previous version
575        """
576        pass
577
578    def calculateFailed(self, reason):
579        """
580        """
581        print("Calculate Failed with:\n", reason)
582        pass
583
584    def calculateComplete(self, d):
585        """
586        Notify the main thread
587        """
588        self.calculationFinishedSignal.emit()
589
590    def complete(self, input, update=None):
591        """
592        Gen compute complete function
593        :Param input: input list [qx_data, qy_data, i_out]
594        """
595        out = numpy.empty(0)
596        for ind in range(len(input[0])):
597            if self.is_avg:
598                if ind % 1 == 0 and update is not None:
599                    # update()
600                    percentage = int(100.0 * float(ind) / len(input[0]))
601                    update(percentage)
602                    time.sleep(0.001)  # 0.1
603                inputi = [input[0][ind:ind + 1], [], input[2][ind:ind + 1]]
604                outi = self.model.run(inputi)
605                out = numpy.append(out, outi)
606            else:
607                if ind % 50 == 0 and update is not None:
608                    percentage = int(100.0 * float(ind) / len(input[0]))
609                    update(percentage)
610                    time.sleep(0.001)
611                inputi = [input[0][ind:ind + 1], input[1][ind:ind + 1],
612                          input[2][ind:ind + 1]]
613                outi = self.model.runXY(inputi)
614                out = numpy.append(out, outi)
615        self.data_to_plot = out
616        logging.info('Gen computation completed.')
617        self.cmdCompute.setText('Compute')
618        self.cmdCompute.setEnabled(True)
619        return
620
621    def onSaveFile(self):
622        """Save data as .sld file"""
623        path = os.path.dirname(str(self.datafile))
624        default_name = os.path.join(path, 'sld_file')
625        kwargs = {
626            'parent': self,
627            'directory': default_name,
628            'filter': 'SLD file (*.sld)',
629            'options': QtWidgets.QFileDialog.DontUseNativeDialog}
630        # Query user for filename.
631        filename_tuple = QtWidgets.QFileDialog.getSaveFileName(**kwargs)
632        filename = filename_tuple[0]
633        if filename:
634            try:
635                _, extension = os.path.splitext(filename)
636                if not extension:
637                    filename = '.'.join((filename, 'sld'))
638                sas_gen.SLDReader().write(filename, self.sld_data)
639            except:
640                raise
641
642    def plot3d(self, has_arrow=False):
643        """ Generate 3D plot in real space with or without arrows """
644        self.write_new_values_from_gui()
645        graph_title = " Graph {}: {} 3D SLD Profile".format(self.graph_num,
646                                                            self.file_name)
647        if has_arrow:
648            graph_title += ' - Magnetic Vector as Arrow'
649
650        plot3D = Plotter3D(self, graph_title)
651        plot3D.plot(self.sld_data, has_arrow=has_arrow)
652        plot3D.show()
653        self.graph_num += 1
654
655    def plot_1_2d(self):
656        """ Generate 1D or 2D plot, called in Compute"""
657        if self.is_avg or self.is_avg is None:
658            data = Data1D(x=self.data.x, y=self.data_to_plot)
659            data.title = "GenSAS {}  #{} 1D".format(self.file_name,
660                                                    int(self.graph_num))
661            data.xaxis('\\rm{Q_{x}}', '\AA^{-1}')
662            data.yaxis('\\rm{Intensity}', 'cm^{-1}')
663
664            self.graph_num += 1
665        else:
666            numpy.nan_to_num(self.data_to_plot)
667            data = Data2D(image=self.data_to_plot,
668                          qx_data=self.data.qx_data,
669                          qy_data=self.data.qy_data,
670                          q_data=self.data.q_data,
671                          xmin=self.data.xmin, xmax=self.data.ymax,
672                          ymin=self.data.ymin, ymax=self.data.ymax,
673                          err_image=self.data.err_data)
674            data.title = "GenSAS {}  #{} 2D".format(self.file_name,
675                                                    int(self.graph_num))
676            zeros = numpy.ones(data.data.size, dtype=bool)
677            data.mask = zeros
678
679            self.graph_num += 1
680            # TODO
681        new_item = GuiUtils.createModelItemWithPlot(data, name=data.title)
682        self.communicator.updateModelFromPerspectiveSignal.emit(new_item)
683        self.communicator.forcePlotDisplaySignal.emit([new_item, data])
684
685class Plotter3DWidget(PlotterBase):
686    """
687    3D Plot widget for use with a QDialog
688    """
689    def __init__(self, parent=None, manager=None):
690        super(Plotter3DWidget, self).__init__(parent,  manager=manager)
691
692    @property
693    def data(self):
694        return self._data
695
696    @data.setter
697    def data(self, data=None):
698        """ data setter """
699        self._data = data
700
701    def plot(self, data=None, has_arrow=False):
702        """
703        Plot 3D self._data
704        """
705        if not data:
706            return
707        self.data = data
708        #assert(self._data)
709        # Prepare and show the plot
710        self.showPlot(data=self.data, has_arrow=has_arrow)
711
712    def showPlot(self, data, has_arrow=False):
713        """
714        Render and show the current data
715        """
716        # If we don't have any data, skip.
717        if data is None:
718            return
719        # This import takes forever - place it here so the main UI starts faster
720        from mpl_toolkits.mplot3d import Axes3D
721        color_dic = {'H': 'blue', 'D': 'purple', 'N': 'orange',
722                     'O': 'red', 'C': 'green', 'P': 'cyan', 'Other': 'k'}
723        marker = ','
724        m_size = 2
725
726        pos_x = data.pos_x
727        pos_y = data.pos_y
728        pos_z = data.pos_z
729        sld_mx = data.sld_mx
730        sld_my = data.sld_my
731        sld_mz = data.sld_mz
732        pix_symbol = data.pix_symbol
733        sld_tot = numpy.fabs(sld_mx) + numpy.fabs(sld_my) + \
734                  numpy.fabs(sld_mz) + numpy.fabs(data.sld_n)
735        is_nonzero = sld_tot > 0.0
736        is_zero = sld_tot == 0.0
737
738        if data.pix_type == 'atom':
739            marker = 'o'
740            m_size = 3.5
741
742        self.figure.clear()
743        self.figure.subplots_adjust(left=0.1, right=.8, bottom=.1)
744        ax = Axes3D(self.figure)
745        ax.set_xlabel('x ($\A{}$)'.format(data.pos_unit))
746        ax.set_ylabel('z ($\A{}$)'.format(data.pos_unit))
747        ax.set_zlabel('y ($\A{}$)'.format(data.pos_unit))
748
749        # I. Plot null points
750        if is_zero.any():
751            im = ax.plot(pos_x[is_zero], pos_z[is_zero], pos_y[is_zero],
752                           marker, c="y", alpha=0.5, markeredgecolor='y',
753                           markersize=m_size)
754            pos_x = pos_x[is_nonzero]
755            pos_y = pos_y[is_nonzero]
756            pos_z = pos_z[is_nonzero]
757            sld_mx = sld_mx[is_nonzero]
758            sld_my = sld_my[is_nonzero]
759            sld_mz = sld_mz[is_nonzero]
760            pix_symbol = data.pix_symbol[is_nonzero]
761        # II. Plot selective points in color
762        other_color = numpy.ones(len(pix_symbol), dtype='bool')
763        for key in list(color_dic.keys()):
764            chosen_color = pix_symbol == key
765            if numpy.any(chosen_color):
766                other_color = other_color & (chosen_color!=True)
767                color = color_dic[key]
768                im = ax.plot(pos_x[chosen_color], pos_z[chosen_color],
769                         pos_y[chosen_color], marker, c=color, alpha=0.5,
770                         markeredgecolor=color, markersize=m_size, label=key)
771        # III. Plot All others
772        if numpy.any(other_color):
773            a_name = ''
774            if data.pix_type == 'atom':
775                # Get atom names not in the list
776                a_names = [symb for symb in pix_symbol \
777                           if symb not in list(color_dic.keys())]
778                a_name = a_names[0]
779                for name in a_names:
780                    new_name = ", " + name
781                    if a_name.count(name) == 0:
782                        a_name += new_name
783            # plot in black
784            im = ax.plot(pos_x[other_color], pos_z[other_color],
785                         pos_y[other_color], marker, c="k", alpha=0.5,
786                         markeredgecolor="k", markersize=m_size, label=a_name)
787        if data.pix_type == 'atom':
788            ax.legend(loc='upper left', prop={'size': 10})
789        # IV. Draws atomic bond with grey lines if any
790        if data.has_conect:
791            for ind in range(len(data.line_x)):
792                im = ax.plot(data.line_x[ind], data.line_z[ind],
793                             data.line_y[ind], '-', lw=0.6, c="grey",
794                             alpha=0.3)
795        # V. Draws magnetic vectors
796        if has_arrow and len(pos_x) > 0:
797            def _draw_arrow(input=None, update=None):
798                # import moved here for performance reasons
799                from sas.qtgui.Plotting.Arrow3D import Arrow3D
800                """
801                draw magnetic vectors w/arrow
802                """
803                max_mx = max(numpy.fabs(sld_mx))
804                max_my = max(numpy.fabs(sld_my))
805                max_mz = max(numpy.fabs(sld_mz))
806                max_m = max(max_mx, max_my, max_mz)
807                try:
808                    max_step = max(data.xstepsize, data.ystepsize, data.zstepsize)
809                except:
810                    max_step = 0
811                if max_step <= 0:
812                    max_step = 5
813                try:
814                    if max_m != 0:
815                        unit_x2 = sld_mx / max_m
816                        unit_y2 = sld_my / max_m
817                        unit_z2 = sld_mz / max_m
818                        # 0.8 is for avoiding the color becomes white=(1,1,1))
819                        color_x = numpy.fabs(unit_x2 * 0.8)
820                        color_y = numpy.fabs(unit_y2 * 0.8)
821                        color_z = numpy.fabs(unit_z2 * 0.8)
822                        x2 = pos_x + unit_x2 * max_step
823                        y2 = pos_y + unit_y2 * max_step
824                        z2 = pos_z + unit_z2 * max_step
825                        x_arrow = numpy.column_stack((pos_x, x2))
826                        y_arrow = numpy.column_stack((pos_y, y2))
827                        z_arrow = numpy.column_stack((pos_z, z2))
828                        colors = numpy.column_stack((color_x, color_y, color_z))
829                        arrows = Arrow3D(self.figure, x_arrow, z_arrow, y_arrow,
830                                        colors, mutation_scale=10, lw=1,
831                                        arrowstyle="->", alpha=0.5)
832                        ax.add_artist(arrows)
833                except:
834                    pass
835                log_msg = "Arrow Drawing completed.\n"
836                logging.info(log_msg)
837            log_msg = "Arrow Drawing is in progress..."
838            logging.info(log_msg)
839
840            # Defer the drawing of arrows to another thread
841            d = threads.deferToThread(_draw_arrow, ax)
842
843        self.figure.canvas.resizing = False
844        self.figure.canvas.draw()
845
846    def createContextMenu(self):
847        """
848        Define common context menu and associated actions for the MPL widget
849        """
850        return
851
852    def createContextMenuQuick(self):
853        """
854        Define context menu and associated actions for the quickplot MPL widget
855        """
856        return
857
858
859class Plotter3D(QtWidgets.QDialog, Plotter3DWidget):
860    def __init__(self, parent=None, graph_title=''):
861        self.graph_title = graph_title
862        QtWidgets.QDialog.__init__(self)
863        Plotter3DWidget.__init__(self, manager=parent)
864        self.setWindowTitle(self.graph_title)
865
Note: See TracBrowser for help on using the repository browser.