source: sasview/src/sas/qtgui/Calculators/GenericScatteringCalculator.py @ b9ab979

ESS_GUIESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalc
Last change on this file since b9ab979 was 30e0be0, checked in by Piotr Rozyczko <piotr.rozyczko@…>, 6 years ago

Updates to the scattering calculator SASVIEW-1147

  • Property mode set to 100644
File size: 33.8 KB
Line 
1import sys
2import os
3import numpy
4import logging
5import time
6
7from PyQt5 import QtCore
8from PyQt5 import QtGui
9from PyQt5 import QtWidgets
10
11from twisted.internet import threads
12
13import sas.qtgui.Utilities.GuiUtils as GuiUtils
14from sas.qtgui.Utilities.GenericReader import GenReader
15from sas.sascalc.dataloader.data_info import Detector
16from sas.sascalc.dataloader.data_info import Source
17from sas.sascalc.calculator import sas_gen
18from sas.qtgui.Plotting.PlotterBase import PlotterBase
19from sas.qtgui.Plotting.Plotter2D import Plotter2D
20from sas.qtgui.Plotting.Plotter import Plotter
21
22from sas.qtgui.Plotting.PlotterData import Data1D
23from sas.qtgui.Plotting.PlotterData import Data2D
24
25# Local UI
26from .UI.GenericScatteringCalculator import Ui_GenericScatteringCalculator
27
28_Q1D_MIN = 0.001
29
30
31class GenericScatteringCalculator(QtWidgets.QDialog, Ui_GenericScatteringCalculator):
32
33    trigger_plot_3d = QtCore.pyqtSignal()
34    calculationFinishedSignal = QtCore.pyqtSignal()
35
36    def __init__(self, parent=None):
37        super(GenericScatteringCalculator, self).__init__()
38        self.setupUi(self)
39        self.manager = parent
40        self.communicator = self.manager.communicator()
41        self.model = sas_gen.GenSAS()
42        self.omf_reader = sas_gen.OMFReader()
43        self.sld_reader = sas_gen.SLDReader()
44        self.pdb_reader = sas_gen.PDBReader()
45        self.reader = None
46        self.sld_data = None
47
48        self.parameters = []
49        self.data = None
50        self.datafile = None
51        self.file_name = ''
52        self.ext = None
53        self.default_shape = str(self.cbShape.currentText())
54        self.is_avg = False
55        self.data_to_plot = None
56        self.graph_num = 1  # index for name of graph
57
58        # combox box
59        self.cbOptionsCalc.setVisible(False)
60
61        # push buttons
62        self.cmdClose.clicked.connect(self.accept)
63        self.cmdHelp.clicked.connect(self.onHelp)
64
65        self.cmdLoad.clicked.connect(self.loadFile)
66        self.cmdCompute.clicked.connect(self.onCompute)
67        self.cmdReset.clicked.connect(self.onReset)
68        self.cmdSave.clicked.connect(self.onSaveFile)
69
70        self.cmdDraw.clicked.connect(lambda: self.plot3d(has_arrow=True))
71        self.cmdDrawpoints.clicked.connect(lambda: self.plot3d(has_arrow=False))
72
73        # validators
74        # scale, volume and background must be positive
75        validat_regex_pos = QtCore.QRegExp('^[+]?([.]\d+|\d+([.]\d+)?)$')
76        self.txtScale.setValidator(QtGui.QRegExpValidator(validat_regex_pos,
77                                                          self.txtScale))
78        self.txtBackground.setValidator(QtGui.QRegExpValidator(
79            validat_regex_pos, self.txtBackground))
80        self.txtTotalVolume.setValidator(QtGui.QRegExpValidator(
81            validat_regex_pos, self.txtTotalVolume))
82
83        # fraction of spin up between 0 and 1
84        validat_regexbetween0_1 = QtCore.QRegExp('^(0(\.\d*)*|1(\.0+)?)$')
85        self.txtUpFracIn.setValidator(
86            QtGui.QRegExpValidator(validat_regexbetween0_1, self.txtUpFracIn))
87        self.txtUpFracOut.setValidator(
88            QtGui.QRegExpValidator(validat_regexbetween0_1, self.txtUpFracOut))
89
90        # 0 < Qmax <= 1000
91        validat_regex_q = QtCore.QRegExp('^1000$|^[+]?(\d{1,3}([.]\d+)?)$')
92        self.txtQxMax.setValidator(QtGui.QRegExpValidator(validat_regex_q,
93                                                          self.txtQxMax))
94        self.txtQxMax.textChanged.connect(self.check_value)
95
96        # 2 <= Qbin <= 1000
97        self.txtNoQBins.setValidator(QtGui.QRegExpValidator(validat_regex_q,
98                                                            self.txtNoQBins))
99        self.txtNoQBins.textChanged.connect(self.check_value)
100
101        # plots - 3D in real space
102        self.trigger_plot_3d.connect(lambda: self.plot3d(has_arrow=False))
103
104        # plots - 3D in real space
105        self.calculationFinishedSignal.connect(self.plot_1_2d)
106
107        # TODO the option Ellipsoid has not been implemented
108        self.cbShape.currentIndexChanged.connect(self.selectedshapechange)
109
110        # New font to display angstrom symbol
111        new_font = 'font-family: -apple-system, "Helvetica Neue", "Ubuntu";'
112        self.lblUnitSolventSLD.setStyleSheet(new_font)
113        self.lblUnitVolume.setStyleSheet(new_font)
114        self.lbl5.setStyleSheet(new_font)
115        self.lblUnitMx.setStyleSheet(new_font)
116        self.lblUnitMy.setStyleSheet(new_font)
117        self.lblUnitMz.setStyleSheet(new_font)
118        self.lblUnitNucl.setStyleSheet(new_font)
119        self.lblUnitx.setStyleSheet(new_font)
120        self.lblUnity.setStyleSheet(new_font)
121        self.lblUnitz.setStyleSheet(new_font)
122
123    def selectedshapechange(self):
124        """
125        TODO Temporary solution to display information about option 'Ellipsoid'
126        """
127        print("The option Ellipsoid has not been implemented yet.")
128        self.communicator.statusBarUpdateSignal.emit(
129            "The option Ellipsoid has not been implemented yet.")
130
131    def loadFile(self):
132        """
133        Open menu to choose the datafile to load
134        Only extensions .SLD, .PDB, .OMF, .sld, .pdb, .omf
135        """
136        try:
137            self.datafile = QtWidgets.QFileDialog.getOpenFileName(
138                self, "Choose a file", "", "All Gen files (*.OMF *.omf) ;;"
139                                          "SLD files (*.SLD *.sld);;PDB files (*.pdb *.PDB);; "
140                                          "OMF files (*.OMF *.omf);; "
141                                          "All files (*.*)")[0]
142            if self.datafile:
143                self.default_shape = str(self.cbShape.currentText())
144                self.file_name = os.path.basename(str(self.datafile))
145                self.ext = os.path.splitext(str(self.datafile))[1]
146                if self.ext in self.omf_reader.ext:
147                    loader = self.omf_reader
148                elif self.ext in self.sld_reader.ext:
149                    loader = self.sld_reader
150                elif self.ext in self.pdb_reader.ext:
151                    loader = self.pdb_reader
152                else:
153                    loader = None
154                # disable some entries depending on type of loaded file
155                # (according to documentation)
156                if self.ext.lower() in ['.sld', '.omf', '.pdb']:
157                    self.txtUpFracIn.setEnabled(False)
158                    self.txtUpFracOut.setEnabled(False)
159                    self.txtUpTheta.setEnabled(False)
160
161                if self.reader is not None and self.reader.isrunning():
162                    self.reader.stop()
163                self.cmdLoad.setEnabled(False)
164                self.cmdLoad.setText('Loading...')
165                self.communicator.statusBarUpdateSignal.emit(
166                    "Loading File {}".format(os.path.basename(
167                        str(self.datafile))))
168                self.reader = GenReader(path=str(self.datafile), loader=loader,
169                                        completefn=self.complete_loading,
170                                        updatefn=self.load_update)
171                self.reader.queue()
172        except (RuntimeError, IOError):
173            log_msg = "Generic SAS Calculator: %s" % sys.exc_info()[1]
174            logging.info(log_msg)
175            raise
176        return
177
178    def load_update(self):
179        """ Legacy function used in GenRead """
180        if self.reader.isrunning():
181            status_type = "progress"
182        else:
183            status_type = "stop"
184        logging.info(status_type)
185
186    def complete_loading(self, data=None):
187        """ Function used in GenRead"""
188        self.cbShape.setEnabled(False)
189        try:
190            is_pdbdata = False
191            self.txtData.setText(os.path.basename(str(self.datafile)))
192            self.is_avg = False
193            if self.ext in self.omf_reader.ext:
194                gen = sas_gen.OMF2SLD()
195                gen.set_data(data)
196                self.sld_data = gen.get_magsld()
197                self.check_units()
198            elif self.ext in self.sld_reader.ext:
199                self.sld_data = data
200            elif self.ext in self.pdb_reader.ext:
201                self.sld_data = data
202                is_pdbdata = True
203            # Display combobox of orientation only for pdb data
204            self.cbOptionsCalc.setVisible(is_pdbdata)
205            self.update_gui()
206        except IOError:
207            log_msg = "Loading Error: " \
208                      "This file format is not supported for GenSAS."
209            logging.info(log_msg)
210            raise
211        except ValueError:
212            log_msg = "Could not find any data"
213            logging.info(log_msg)
214            raise
215        logging.info("Load Complete")
216        self.cmdLoad.setEnabled(True)
217        self.cmdLoad.setText('Load')
218        self.trigger_plot_3d.emit()
219
220    def check_units(self):
221        """
222        Check if the units from the OMF file correspond to the default ones
223        displayed on the interface.
224        If not, modify the GUI with the correct unit
225        """
226        #  TODO: adopt the convention of font and symbol for the updated values
227        if sas_gen.OMFData().valueunit != 'A^(-2)':
228            value_unit = sas_gen.OMFData().valueunit
229            self.lbl_unitMx.setText(value_unit)
230            self.lbl_unitMy.setText(value_unit)
231            self.lbl_unitMz.setText(value_unit)
232            self.lbl_unitNucl.setText(value_unit)
233        if sas_gen.OMFData().meshunit != 'A':
234            mesh_unit = sas_gen.OMFData().meshunit
235            self.lbl_unitx.setText(mesh_unit)
236            self.lbl_unity.setText(mesh_unit)
237            self.lbl_unitz.setText(mesh_unit)
238            self.lbl_unitVolume.setText(mesh_unit+"^3")
239
240    def check_value(self):
241        """Check range of text edits for QMax and Number of Qbins """
242        text_edit = self.sender()
243        text_edit.setStyleSheet('background-color: rgb(255, 255, 255);')
244        if text_edit.text():
245            value = float(str(text_edit.text()))
246            if text_edit == self.txtQxMax:
247                if value <= 0 or value > 1000:
248                    text_edit.setStyleSheet('background-color: rgb(255, 182, 193);')
249                else:
250                    text_edit.setStyleSheet('background-color: rgb(255, 255, 255);')
251            elif text_edit == self.txtNoQBins:
252                if value < 2 or value > 1000:
253                    self.txtNoQBins.setStyleSheet('background-color: rgb(255, 182, 193);')
254                else:
255                    self.txtNoQBins.setStyleSheet('background-color: rgb(255, 255, 255);')
256
257    def update_gui(self):
258        """ Update the interface with values from loaded data """
259        self.model.set_is_avg(self.is_avg)
260        self.model.set_sld_data(self.sld_data)
261        self.model.params['total_volume'] = len(self.sld_data.sld_n)*self.sld_data.vol_pix[0]
262
263        # add condition for activation of save button
264        self.cmdSave.setEnabled(True)
265
266        # activation of 3D plots' buttons (with and without arrows)
267        self.cmdDraw.setEnabled(self.sld_data is not None)
268        self.cmdDrawpoints.setEnabled(self.sld_data is not None)
269
270        self.txtScale.setText(str(self.model.params['scale']))
271        self.txtBackground.setText(str(self.model.params['background']))
272        self.txtSolventSLD.setText(str(self.model.params['solvent_SLD']))
273
274        # Volume to write to interface: npts x volume of first pixel
275        self.txtTotalVolume.setText(str(len(self.sld_data.sld_n)*self.sld_data.vol_pix[0]))
276
277        self.txtUpFracIn.setText(str(self.model.params['Up_frac_in']))
278        self.txtUpFracOut.setText(str(self.model.params['Up_frac_out']))
279        self.txtUpTheta.setText(str(self.model.params['Up_theta']))
280
281        self.txtNoPixels.setText(str(len(self.sld_data.sld_n)))
282        self.txtNoPixels.setEnabled(False)
283
284        list_parameters = ['sld_mx', 'sld_my', 'sld_mz', 'sld_n', 'xnodes',
285                           'ynodes', 'znodes', 'xstepsize', 'ystepsize',
286                           'zstepsize']
287        list_gui_button = [self.txtMx, self.txtMy, self.txtMz, self.txtNucl,
288                           self.txtXnodes, self.txtYnodes, self.txtZnodes,
289                           self.txtXstepsize, self.txtYstepsize,
290                           self.txtZstepsize]
291
292        # Fill right hand side of GUI
293        for indx, item in enumerate(list_parameters):
294            if getattr(self.sld_data, item) is None:
295                list_gui_button[indx].setText('NaN')
296            else:
297                value = getattr(self.sld_data, item)
298                if isinstance(value, numpy.ndarray):
299                    item_for_gui = str(GuiUtils.formatNumber(numpy.average(value), True))
300                else:
301                    item_for_gui = str(GuiUtils.formatNumber(value, True))
302                list_gui_button[indx].setText(item_for_gui)
303
304        # Enable / disable editing of right hand side of GUI
305        for indx, item in enumerate(list_parameters):
306            if indx < 4:
307                # this condition only applies to Mx,y,z and Nucl
308                value = getattr(self.sld_data, item)
309                enable = self.sld_data.pix_type == 'pixel' \
310                         and numpy.min(value) == numpy.max(value)
311            else:
312                enable = not self.sld_data.is_data
313            list_gui_button[indx].setEnabled(enable)
314
315    def write_new_values_from_gui(self):
316        """
317        update parameters using modified inputs from GUI
318        used before computing
319        """
320        if self.txtScale.isModified():
321            self.model.params['scale'] = float(self.txtScale.text())
322
323        if self.txtBackground.isModified():
324            self.model.params['background'] = float(self.txtBackground.text())
325
326        if self.txtSolventSLD.isModified():
327            self.model.params['solvent_SLD'] = float(self.txtSolventSLD.text())
328
329        # Different condition for total volume to get correct volume after
330        # applying set_sld_data in compute
331        if self.txtTotalVolume.isModified() \
332                or self.model.params['total_volume'] != float(self.txtTotalVolume.text()):
333            self.model.params['total_volume'] = float(self.txtTotalVolume.text())
334
335        if self.txtUpFracIn.isModified():
336            self.model.params['Up_frac_in'] = float(self.txtUpFracIn.text())
337
338        if self.txtUpFracOut.isModified():
339            self.model.params['Up_frac_out'] = float(self.txtUpFracOut.text())
340
341        if self.txtUpTheta.isModified():
342            self.model.params['Up_theta'] = float(self.txtUpTheta.text())
343
344        if self.txtMx.isModified():
345            self.sld_data.sld_mx = float(self.txtMx.text())*\
346                                   numpy.ones(len(self.sld_data.sld_mx))
347
348        if self.txtMy.isModified():
349            self.sld_data.sld_my = float(self.txtMy.text())*\
350                                   numpy.ones(len(self.sld_data.sld_my))
351
352        if self.txtMz.isModified():
353            self.sld_data.sld_mz = float(self.txtMz.text())*\
354                                   numpy.ones(len(self.sld_data.sld_mz))
355
356        if self.txtNucl.isModified():
357            self.sld_data.sld_n = float(self.txtNucl.text())*\
358                                  numpy.ones(len(self.sld_data.sld_n))
359
360        if self.txtXnodes.isModified():
361            self.sld_data.xnodes = int(self.txtXnodes.text())
362
363        if self.txtYnodes.isModified():
364            self.sld_data.ynodes = int(self.txtYnodes.text())
365
366        if self.txtZnodes.isModified():
367            self.sld_data.znodes = int(self.txtZnodes.text())
368
369        if self.txtXstepsize.isModified():
370            self.sld_data.xstepsize = float(self.txtXstepsize.text())
371
372        if self.txtYstepsize.isModified():
373            self.sld_data.ystepsize = float(self.txtYstepsize.text())
374
375        if self.txtZstepsize.isModified():
376            self.sld_data.zstepsize = float(self.txtZstepsize.text())
377
378        if self.cbOptionsCalc.isVisible():
379            self.is_avg = (self.cbOptionsCalc.currentIndex() == 1)
380
381    def onHelp(self):
382        """
383        Bring up the Generic Scattering calculator Documentation whenever
384        the HELP button is clicked.
385        Calls Documentation Window with the path of the location within the
386        documentation tree (after /doc/ ....".
387        """
388        location = "/user/qtgui/Calculators/sas_calculator_help.html"
389        self.manager.showHelp(location)
390
391    def onReset(self):
392        """ Reset the inputs of textEdit to default values """
393        try:
394            # reset values in textedits
395            self.txtUpFracIn.setText("1.0")
396            self.txtUpFracOut.setText("1.0")
397            self.txtUpTheta.setText("0.0")
398            self.txtBackground.setText("0.0")
399            self.txtScale.setText("1.0")
400            self.txtSolventSLD.setText("0.0")
401            self.txtTotalVolume.setText("216000.0")
402            self.txtNoQBins.setText("50")
403            self.txtQxMax.setText("0.3")
404            self.txtNoPixels.setText("1000")
405            self.txtMx.setText("0")
406            self.txtMy.setText("0")
407            self.txtMz.setText("0")
408            self.txtNucl.setText("6.97e-06")
409            self.txtXnodes.setText("10")
410            self.txtYnodes.setText("10")
411            self.txtZnodes.setText("10")
412            self.txtXstepsize.setText("6")
413            self.txtYstepsize.setText("6")
414            self.txtZstepsize.setText("6")
415            # reset Load button and textedit
416            self.txtData.setText('Default SLD Profile')
417            self.cmdLoad.setEnabled(True)
418            self.cmdLoad.setText('Load')
419            # reset option for calculation
420            self.cbOptionsCalc.setCurrentIndex(0)
421            self.cbOptionsCalc.setVisible(False)
422            # reset shape button
423            self.cbShape.setCurrentIndex(0)
424            self.cbShape.setEnabled(True)
425            # reset compute button
426            self.cmdCompute.setText('Compute')
427            self.cmdCompute.setEnabled(True)
428            # TODO reload default data set
429            self._create_default_sld_data()
430
431        finally:
432            pass
433
434    def _create_default_2d_data(self):
435        """
436        Copied from previous version
437        Create 2D data by default
438        :warning: This data is never plotted.
439        """
440        self.qmax_x = float(self.txtQxMax.text())
441        self.npts_x = int(self.txtNoQBins.text())
442        self.data = Data2D()
443        self.data.is_data = False
444        # # Default values
445        self.data.detector.append(Detector())
446        index = len(self.data.detector) - 1
447        self.data.detector[index].distance = 8000  # mm
448        self.data.source.wavelength = 6  # A
449        self.data.detector[index].pixel_size.x = 5  # mm
450        self.data.detector[index].pixel_size.y = 5  # mm
451        self.data.detector[index].beam_center.x = self.qmax_x
452        self.data.detector[index].beam_center.y = self.qmax_x
453        xmax = self.qmax_x
454        xmin = -xmax
455        ymax = self.qmax_x
456        ymin = -ymax
457        qstep = self.npts_x
458
459        x = numpy.linspace(start=xmin, stop=xmax, num=qstep, endpoint=True)
460        y = numpy.linspace(start=ymin, stop=ymax, num=qstep, endpoint=True)
461        # use data info instead
462        new_x = numpy.tile(x, (len(y), 1))
463        new_y = numpy.tile(y, (len(x), 1))
464        new_y = new_y.swapaxes(0, 1)
465        # all data require now in 1d array
466        qx_data = new_x.flatten()
467        qy_data = new_y.flatten()
468        q_data = numpy.sqrt(qx_data * qx_data + qy_data * qy_data)
469        # set all True (standing for unmasked) as default
470        mask = numpy.ones(len(qx_data), dtype=bool)
471        self.data.source = Source()
472        self.data.data = numpy.ones(len(mask))
473        self.data.err_data = numpy.ones(len(mask))
474        self.data.qx_data = qx_data
475        self.data.qy_data = qy_data
476        self.data.q_data = q_data
477        self.data.mask = mask
478        # store x and y bin centers in q space
479        self.data.x_bins = x
480        self.data.y_bins = y
481        # max and min taking account of the bin sizes
482        self.data.xmin = xmin
483        self.data.xmax = xmax
484        self.data.ymin = ymin
485        self.data.ymax = ymax
486
487    def _create_default_sld_data(self):
488        """
489        Copied from previous version
490        Making default sld-data
491        """
492        sld_n_default = 6.97e-06  # what is this number??
493        omfdata = sas_gen.OMFData()
494        omf2sld = sas_gen.OMF2SLD()
495        omf2sld.set_data(omfdata, self.default_shape)
496        self.sld_data = omf2sld.output
497        self.sld_data.is_data = False
498        self.sld_data.filename = "Default SLD Profile"
499        self.sld_data.set_sldn(sld_n_default)
500
501    def _create_default_1d_data(self):
502        """
503        Copied from previous version
504        Create 1D data by default
505        :warning: This data is never plotted.
506                    residuals.x = data_copy.x[index]
507            residuals.dy = numpy.ones(len(residuals.y))
508            residuals.dx = None
509            residuals.dxl = None
510            residuals.dxw = None
511        """
512        self.qmax_x = float(self.txtQxMax.text())
513        self.npts_x = int(self.txtNoQBins.text())
514        #  Default values
515        xmax = self.qmax_x
516        xmin = self.qmax_x * _Q1D_MIN
517        qstep = self.npts_x
518        x = numpy.linspace(start=xmin, stop=xmax, num=qstep, endpoint=True)
519        # store x and y bin centers in q space
520        y = numpy.ones(len(x))
521        dy = numpy.zeros(len(x))
522        dx = numpy.zeros(len(x))
523        self.data = Data1D(x=x, y=y)
524        self.data.dx = dx
525        self.data.dy = dy
526
527    def onCompute(self):
528        """
529        Copied from previous version
530        Execute the computation of I(qx, qy)
531        """
532        # Set default data when nothing loaded yet
533        if self.sld_data is None:
534            self._create_default_sld_data()
535        try:
536            self.model.set_sld_data(self.sld_data)
537            self.write_new_values_from_gui()
538            if self.is_avg or self.is_avg is None:
539                self._create_default_1d_data()
540                i_out = numpy.zeros(len(self.data.y))
541                inputs = [self.data.x, [], i_out]
542            else:
543                self._create_default_2d_data()
544                i_out = numpy.zeros(len(self.data.data))
545                inputs = [self.data.qx_data, self.data.qy_data, i_out]
546            logging.info("Computation is in progress...")
547            self.cmdCompute.setText('Wait...')
548            self.cmdCompute.setEnabled(False)
549            d = threads.deferToThread(self.complete, inputs, self._update)
550            # Add deferred callback for call return
551            #d.addCallback(self.plot_1_2d)
552            d.addCallback(self.calculateComplete)
553            d.addErrback(self.calculateFailed)
554        except:
555            log_msg = "{}. stop".format(sys.exc_info()[1])
556            logging.info(log_msg)
557        return
558
559    def _update(self, value):
560        """
561        Copied from previous version
562        """
563        pass
564
565    def calculateFailed(self, reason):
566        """
567        """
568        print("Calculate Failed with:\n", reason)
569        pass
570
571    def calculateComplete(self, d):
572        """
573        Notify the main thread
574        """
575        self.calculationFinishedSignal.emit()
576
577    def complete(self, input, update=None):
578        """
579        Gen compute complete function
580        :Param input: input list [qx_data, qy_data, i_out]
581        """
582        out = numpy.empty(0)
583        for ind in range(len(input[0])):
584            if self.is_avg:
585                if ind % 1 == 0 and update is not None:
586                    # update()
587                    percentage = int(100.0 * float(ind) / len(input[0]))
588                    update(percentage)
589                    time.sleep(0.001)  # 0.1
590                inputi = [input[0][ind:ind + 1], [], input[2][ind:ind + 1]]
591                outi = self.model.run(inputi)
592                out = numpy.append(out, outi)
593            else:
594                if ind % 50 == 0 and update is not None:
595                    percentage = int(100.0 * float(ind) / len(input[0]))
596                    update(percentage)
597                    time.sleep(0.001)
598                inputi = [input[0][ind:ind + 1], input[1][ind:ind + 1],
599                          input[2][ind:ind + 1]]
600                outi = self.model.runXY(inputi)
601                out = numpy.append(out, outi)
602        self.data_to_plot = out
603        logging.info('Gen computation completed.')
604        self.cmdCompute.setText('Compute')
605        self.cmdCompute.setEnabled(True)
606        return
607
608    def onSaveFile(self):
609        """Save data as .sld file"""
610        path = os.path.dirname(str(self.datafile))
611        default_name = os.path.join(path, 'sld_file')
612        kwargs = {
613            'parent': self,
614            'directory': default_name,
615            'filter': 'SLD file (*.sld)',
616            'options': QtWidgets.QFileDialog.DontUseNativeDialog}
617        # Query user for filename.
618        filename_tuple = QtWidgets.QFileDialog.getSaveFileName(**kwargs)
619        filename = filename_tuple[0]
620        if filename:
621            try:
622                _, extension = os.path.splitext(filename)
623                if not extension:
624                    filename = '.'.join((filename, 'sld'))
625                sas_gen.SLDReader().write(filename, self.sld_data)
626            except:
627                raise
628
629    def plot3d(self, has_arrow=False):
630        """ Generate 3D plot in real space with or without arrows """
631        self.write_new_values_from_gui()
632        graph_title = " Graph {}: {} 3D SLD Profile".format(self.graph_num,
633                                                            self.file_name)
634        if has_arrow:
635            graph_title += ' - Magnetic Vector as Arrow'
636
637        plot3D = Plotter3D(self, graph_title)
638        plot3D.plot(self.sld_data, has_arrow=has_arrow)
639        plot3D.show()
640        self.graph_num += 1
641
642    def plot_1_2d(self):
643        """ Generate 1D or 2D plot, called in Compute"""
644        if self.is_avg or self.is_avg is None:
645            data = Data1D(x=self.data.x, y=self.data_to_plot)
646            data.title = "GenSAS {}  #{} 1D".format(self.file_name,
647                                                    int(self.graph_num))
648            data.xaxis('\\rm{Q_{x}}', '\AA^{-1}')
649            data.yaxis('\\rm{Intensity}', 'cm^{-1}')
650            plot1D = Plotter(self, quickplot=True)
651            plot1D.plot(data)
652            plot1D.show()
653            self.graph_num += 1
654            # TODO
655            print('TRANSFER OF DATA TO MAIN PANEL TO BE IMPLEMENTED')
656            return plot1D
657        else:
658            numpy.nan_to_num(self.data_to_plot)
659            data = Data2D(image=self.data_to_plot,
660                          qx_data=self.data.qx_data,
661                          qy_data=self.data.qy_data,
662                          q_data=self.data.q_data,
663                          xmin=self.data.xmin, xmax=self.data.ymax,
664                          ymin=self.data.ymin, ymax=self.data.ymax,
665                          err_image=self.data.err_data)
666            data.title = "GenSAS {}  #{} 2D".format(self.file_name,
667                                                    int(self.graph_num))
668            plot2D = Plotter2D(self, quickplot=True)
669            plot2D.plot(data)
670            plot2D.show()
671            self.graph_num += 1
672            # TODO
673            print('TRANSFER OF DATA TO MAIN PANEL TO BE IMPLEMENTED')
674            return plot2D
675
676
677class Plotter3DWidget(PlotterBase):
678    """
679    3D Plot widget for use with a QDialog
680    """
681    def __init__(self, parent=None, manager=None):
682        super(Plotter3DWidget, self).__init__(parent,  manager=manager)
683
684    @property
685    def data(self):
686        return self._data
687
688    @data.setter
689    def data(self, data=None):
690        """ data setter """
691        self._data = data
692
693    def plot(self, data=None, has_arrow=False):
694        """
695        Plot 3D self._data
696        """
697        if not data:
698            return
699        self.data = data
700        #assert(self._data)
701        # Prepare and show the plot
702        self.showPlot(data=self.data, has_arrow=has_arrow)
703
704    def showPlot(self, data, has_arrow=False):
705        """
706        Render and show the current data
707        """
708        # If we don't have any data, skip.
709        if data is None:
710            return
711        # This import takes forever - place it here so the main UI starts faster
712        from mpl_toolkits.mplot3d import Axes3D
713        color_dic = {'H': 'blue', 'D': 'purple', 'N': 'orange',
714                     'O': 'red', 'C': 'green', 'P': 'cyan', 'Other': 'k'}
715        marker = ','
716        m_size = 2
717
718        pos_x = data.pos_x
719        pos_y = data.pos_y
720        pos_z = data.pos_z
721        sld_mx = data.sld_mx
722        sld_my = data.sld_my
723        sld_mz = data.sld_mz
724        pix_symbol = data.pix_symbol
725        sld_tot = numpy.fabs(sld_mx) + numpy.fabs(sld_my) + \
726                  numpy.fabs(sld_mz) + numpy.fabs(data.sld_n)
727        is_nonzero = sld_tot > 0.0
728        is_zero = sld_tot == 0.0
729
730        if data.pix_type == 'atom':
731            marker = 'o'
732            m_size = 3.5
733
734        self.figure.clear()
735        self.figure.subplots_adjust(left=0.1, right=.8, bottom=.1)
736        ax = Axes3D(self.figure)
737        ax.set_xlabel('x ($\A{}$)'.format(data.pos_unit))
738        ax.set_ylabel('z ($\A{}$)'.format(data.pos_unit))
739        ax.set_zlabel('y ($\A{}$)'.format(data.pos_unit))
740
741        # I. Plot null points
742        if is_zero.any():
743            im = ax.plot(pos_x[is_zero], pos_z[is_zero], pos_y[is_zero],
744                           marker, c="y", alpha=0.5, markeredgecolor='y',
745                           markersize=m_size)
746            pos_x = pos_x[is_nonzero]
747            pos_y = pos_y[is_nonzero]
748            pos_z = pos_z[is_nonzero]
749            sld_mx = sld_mx[is_nonzero]
750            sld_my = sld_my[is_nonzero]
751            sld_mz = sld_mz[is_nonzero]
752            pix_symbol = data.pix_symbol[is_nonzero]
753        # II. Plot selective points in color
754        other_color = numpy.ones(len(pix_symbol), dtype='bool')
755        for key in list(color_dic.keys()):
756            chosen_color = pix_symbol == key
757            if numpy.any(chosen_color):
758                other_color = other_color & (chosen_color!=True)
759                color = color_dic[key]
760                im = ax.plot(pos_x[chosen_color], pos_z[chosen_color],
761                         pos_y[chosen_color], marker, c=color, alpha=0.5,
762                         markeredgecolor=color, markersize=m_size, label=key)
763        # III. Plot All others
764        if numpy.any(other_color):
765            a_name = ''
766            if data.pix_type == 'atom':
767                # Get atom names not in the list
768                a_names = [symb for symb in pix_symbol \
769                           if symb not in list(color_dic.keys())]
770                a_name = a_names[0]
771                for name in a_names:
772                    new_name = ", " + name
773                    if a_name.count(name) == 0:
774                        a_name += new_name
775            # plot in black
776            im = ax.plot(pos_x[other_color], pos_z[other_color],
777                         pos_y[other_color], marker, c="k", alpha=0.5,
778                         markeredgecolor="k", markersize=m_size, label=a_name)
779        if data.pix_type == 'atom':
780            ax.legend(loc='upper left', prop={'size': 10})
781        # IV. Draws atomic bond with grey lines if any
782        if data.has_conect:
783            for ind in range(len(data.line_x)):
784                im = ax.plot(data.line_x[ind], data.line_z[ind],
785                             data.line_y[ind], '-', lw=0.6, c="grey",
786                             alpha=0.3)
787        # V. Draws magnetic vectors
788        if has_arrow and len(pos_x) > 0:
789            def _draw_arrow(input=None, update=None):
790                # import moved here for performance reasons
791                from sas.qtgui.Plotting.Arrow3D import Arrow3D
792                """
793                draw magnetic vectors w/arrow
794                """
795                max_mx = max(numpy.fabs(sld_mx))
796                max_my = max(numpy.fabs(sld_my))
797                max_mz = max(numpy.fabs(sld_mz))
798                max_m = max(max_mx, max_my, max_mz)
799                try:
800                    max_step = max(data.xstepsize, data.ystepsize, data.zstepsize)
801                except:
802                    max_step = 0
803                if max_step <= 0:
804                    max_step = 5
805                try:
806                    if max_m != 0:
807                        unit_x2 = sld_mx / max_m
808                        unit_y2 = sld_my / max_m
809                        unit_z2 = sld_mz / max_m
810                        # 0.8 is for avoiding the color becomes white=(1,1,1))
811                        color_x = numpy.fabs(unit_x2 * 0.8)
812                        color_y = numpy.fabs(unit_y2 * 0.8)
813                        color_z = numpy.fabs(unit_z2 * 0.8)
814                        x2 = pos_x + unit_x2 * max_step
815                        y2 = pos_y + unit_y2 * max_step
816                        z2 = pos_z + unit_z2 * max_step
817                        x_arrow = numpy.column_stack((pos_x, x2))
818                        y_arrow = numpy.column_stack((pos_y, y2))
819                        z_arrow = numpy.column_stack((pos_z, z2))
820                        colors = numpy.column_stack((color_x, color_y, color_z))
821                        arrows = Arrow3D(self.figure, x_arrow, z_arrow, y_arrow,
822                                        colors, mutation_scale=10, lw=1,
823                                        arrowstyle="->", alpha=0.5)
824                        ax.add_artist(arrows)
825                except:
826                    pass
827                log_msg = "Arrow Drawing completed.\n"
828                logging.info(log_msg)
829            log_msg = "Arrow Drawing is in progress..."
830            logging.info(log_msg)
831
832            # Defer the drawing of arrows to another thread
833            d = threads.deferToThread(_draw_arrow, ax)
834
835        self.figure.canvas.resizing = False
836        self.figure.canvas.draw()
837
838    def createContextMenu(self):
839        """
840        Define common context menu and associated actions for the MPL widget
841        """
842        return
843
844    def createContextMenuQuick(self):
845        """
846        Define context menu and associated actions for the quickplot MPL widget
847        """
848        return
849
850
851class Plotter3D(QtWidgets.QDialog, Plotter3DWidget):
852    def __init__(self, parent=None, graph_title=''):
853        self.graph_title = graph_title
854        QtWidgets.QDialog.__init__(self)
855        Plotter3DWidget.__init__(self, manager=parent)
856        self.setWindowTitle(self.graph_title)
857
Note: See TracBrowser for help on using the repository browser.