source: sasview/src/sas/qtgui/Calculators/GenericScatteringCalculator.py @ 46ca1f4

Last change on this file since 46ca1f4 was aed0532, checked in by Piotr Rozyczko <rozyczko@…>, 7 years ago

Updated references to help files

  • Property mode set to 100644
File size: 33.2 KB
Line 
1import sys
2import os
3import numpy
4import logging
5import time
6
7from PyQt5 import QtCore
8from PyQt5 import QtGui
9from PyQt5 import QtWidgets
10
11from twisted.internet import threads
12
13import sas.qtgui.Utilities.GuiUtils as GuiUtils
14from sas.qtgui.Utilities.GenericReader import GenReader
15from sas.sascalc.dataloader.data_info import Detector
16from sas.sascalc.dataloader.data_info import Source
17from sas.sascalc.calculator import sas_gen
18from sas.qtgui.Plotting.PlotterBase import PlotterBase
19from sas.qtgui.Plotting.Plotter2D import Plotter2D
20from sas.qtgui.Plotting.Plotter import Plotter
21
22from sas.qtgui.Plotting.PlotterData import Data1D
23from sas.qtgui.Plotting.PlotterData import Data2D
24
25# Local UI
26from .UI.GenericScatteringCalculator import Ui_GenericScatteringCalculator
27
28_Q1D_MIN = 0.001
29
30
31class GenericScatteringCalculator(QtWidgets.QDialog, Ui_GenericScatteringCalculator):
32
33    trigger_plot_3d = QtCore.pyqtSignal()
34
35    def __init__(self, parent=None):
36        super(GenericScatteringCalculator, self).__init__()
37        self.setupUi(self)
38        self.manager = parent
39        self.communicator = self.manager.communicator()
40        self.model = sas_gen.GenSAS()
41        self.omf_reader = sas_gen.OMFReader()
42        self.sld_reader = sas_gen.SLDReader()
43        self.pdb_reader = sas_gen.PDBReader()
44        self.reader = None
45        self.sld_data = None
46
47        self.parameters = []
48        self.data = None
49        self.datafile = None
50        self.file_name = ''
51        self.ext = None
52        self.default_shape = str(self.cbShape.currentText())
53        self.is_avg = False
54        self.data_to_plot = None
55        self.graph_num = 1  # index for name of graph
56
57        # combox box
58        self.cbOptionsCalc.setVisible(False)
59
60        # push buttons
61        self.cmdClose.clicked.connect(self.accept)
62        self.cmdHelp.clicked.connect(self.onHelp)
63
64        self.cmdLoad.clicked.connect(self.loadFile)
65        self.cmdCompute.clicked.connect(self.onCompute)
66        self.cmdReset.clicked.connect(self.onReset)
67        self.cmdSave.clicked.connect(self.onSaveFile)
68
69        self.cmdDraw.clicked.connect(lambda: self.plot3d(has_arrow=True))
70        self.cmdDrawpoints.clicked.connect(lambda: self.plot3d(has_arrow=False))
71
72        # validators
73        # scale, volume and background must be positive
74        validat_regex_pos = QtCore.QRegExp('^[+]?([.]\d+|\d+([.]\d+)?)$')
75        self.txtScale.setValidator(QtGui.QRegExpValidator(validat_regex_pos,
76                                                          self.txtScale))
77        self.txtBackground.setValidator(QtGui.QRegExpValidator(
78            validat_regex_pos, self.txtBackground))
79        self.txtTotalVolume.setValidator(QtGui.QRegExpValidator(
80            validat_regex_pos, self.txtTotalVolume))
81
82        # fraction of spin up between 0 and 1
83        validat_regexbetween0_1 = QtCore.QRegExp('^(0(\.\d*)*|1(\.0+)?)$')
84        self.txtUpFracIn.setValidator(
85            QtGui.QRegExpValidator(validat_regexbetween0_1, self.txtUpFracIn))
86        self.txtUpFracOut.setValidator(
87            QtGui.QRegExpValidator(validat_regexbetween0_1, self.txtUpFracOut))
88
89        # 0 < Qmax <= 1000
90        validat_regex_q = QtCore.QRegExp('^1000$|^[+]?(\d{1,3}([.]\d+)?)$')
91        self.txtQxMax.setValidator(QtGui.QRegExpValidator(validat_regex_q,
92                                                          self.txtQxMax))
93        self.txtQxMax.textChanged.connect(self.check_value)
94
95        # 2 <= Qbin <= 1000
96        self.txtNoQBins.setValidator(QtGui.QRegExpValidator(validat_regex_q,
97                                                            self.txtNoQBins))
98        self.txtNoQBins.textChanged.connect(self.check_value)
99
100        # plots - 3D in real space
101        self.trigger_plot_3d.connect(lambda: self.plot3d(has_arrow=False))
102
103        # TODO the option Ellipsoid has not been implemented
104        self.cbShape.currentIndexChanged.connect(self.selectedshapechange)
105
106        # New font to display angstrom symbol
107        new_font = 'font-family: -apple-system, "Helvetica Neue", "Ubuntu";'
108        self.lblUnitSolventSLD.setStyleSheet(new_font)
109        self.lblUnitVolume.setStyleSheet(new_font)
110        self.lbl5.setStyleSheet(new_font)
111        self.lblUnitMx.setStyleSheet(new_font)
112        self.lblUnitMy.setStyleSheet(new_font)
113        self.lblUnitMz.setStyleSheet(new_font)
114        self.lblUnitNucl.setStyleSheet(new_font)
115        self.lblUnitx.setStyleSheet(new_font)
116        self.lblUnity.setStyleSheet(new_font)
117        self.lblUnitz.setStyleSheet(new_font)
118
119    def selectedshapechange(self):
120        """
121        TODO Temporary solution to display information about option 'Ellipsoid'
122        """
123        print("The option Ellipsoid has not been implemented yet.")
124        self.communicator.statusBarUpdateSignal.emit(
125            "The option Ellipsoid has not been implemented yet.")
126
127    def loadFile(self):
128        """
129        Open menu to choose the datafile to load
130        Only extensions .SLD, .PDB, .OMF, .sld, .pdb, .omf
131        """
132        try:
133            self.datafile = QtWidgets.QFileDialog.getOpenFileName(
134                self, "Choose a file", "", "All Gen files (*.OMF *.omf) ;;"
135                                          "SLD files (*.SLD *.sld);;PDB files (*.pdb *.PDB);; "
136                                          "OMF files (*.OMF *.omf);; "
137                                          "All files (*.*)")[0]
138            if self.datafile:
139                self.default_shape = str(self.cbShape.currentText())
140                self.file_name = os.path.basename(str(self.datafile))
141                self.ext = os.path.splitext(str(self.datafile))[1]
142                if self.ext in self.omf_reader.ext:
143                    loader = self.omf_reader
144                elif self.ext in self.sld_reader.ext:
145                    loader = self.sld_reader
146                elif self.ext in self.pdb_reader.ext:
147                    loader = self.pdb_reader
148                else:
149                    loader = None
150                # disable some entries depending on type of loaded file
151                # (according to documentation)
152                if self.ext.lower() in ['.sld', '.omf', '.pdb']:
153                    self.txtUpFracIn.setEnabled(False)
154                    self.txtUpFracOut.setEnabled(False)
155                    self.txtUpTheta.setEnabled(False)
156
157                if self.reader is not None and self.reader.isrunning():
158                    self.reader.stop()
159                self.cmdLoad.setEnabled(False)
160                self.cmdLoad.setText('Loading...')
161                self.communicator.statusBarUpdateSignal.emit(
162                    "Loading File {}".format(os.path.basename(
163                        str(self.datafile))))
164                self.reader = GenReader(path=str(self.datafile), loader=loader,
165                                        completefn=self.complete_loading,
166                                        updatefn=self.load_update)
167                self.reader.queue()
168        except (RuntimeError, IOError):
169            log_msg = "Generic SAS Calculator: %s" % sys.exc_info()[1]
170            logging.info(log_msg)
171            raise
172        return
173
174    def load_update(self):
175        """ Legacy function used in GenRead """
176        if self.reader.isrunning():
177            status_type = "progress"
178        else:
179            status_type = "stop"
180        logging.info(status_type)
181
182    def complete_loading(self, data=None):
183        """ Function used in GenRead"""
184        self.cbShape.setEnabled(False)
185        try:
186            is_pdbdata = False
187            self.txtData.setText(os.path.basename(str(self.datafile)))
188            self.is_avg = False
189            if self.ext in self.omf_reader.ext:
190                gen = sas_gen.OMF2SLD()
191                gen.set_data(data)
192                self.sld_data = gen.get_magsld()
193                self.check_units()
194            elif self.ext in self.sld_reader.ext:
195                self.sld_data = data
196            elif self.ext in self.pdb_reader.ext:
197                self.sld_data = data
198                is_pdbdata = True
199            # Display combobox of orientation only for pdb data
200            self.cbOptionsCalc.setVisible(is_pdbdata)
201            self.update_gui()
202        except IOError:
203            log_msg = "Loading Error: " \
204                      "This file format is not supported for GenSAS."
205            logging.info(log_msg)
206            raise
207        except ValueError:
208            log_msg = "Could not find any data"
209            logging.info(log_msg)
210            raise
211        logging.info("Load Complete")
212        self.cmdLoad.setEnabled(True)
213        self.cmdLoad.setText('Load')
214        self.trigger_plot_3d.emit()
215
216    def check_units(self):
217        """
218        Check if the units from the OMF file correspond to the default ones
219        displayed on the interface.
220        If not, modify the GUI with the correct unit
221        """
222        #  TODO: adopt the convention of font and symbol for the updated values
223        if sas_gen.OMFData().valueunit != 'A^(-2)':
224            value_unit = sas_gen.OMFData().valueunit
225            self.lbl_unitMx.setText(value_unit)
226            self.lbl_unitMy.setText(value_unit)
227            self.lbl_unitMz.setText(value_unit)
228            self.lbl_unitNucl.setText(value_unit)
229        if sas_gen.OMFData().meshunit != 'A':
230            mesh_unit = sas_gen.OMFData().meshunit
231            self.lbl_unitx.setText(mesh_unit)
232            self.lbl_unity.setText(mesh_unit)
233            self.lbl_unitz.setText(mesh_unit)
234            self.lbl_unitVolume.setText(mesh_unit+"^3")
235
236    def check_value(self):
237        """Check range of text edits for QMax and Number of Qbins """
238        text_edit = self.sender()
239        text_edit.setStyleSheet('background-color: rgb(255, 255, 255);')
240        if text_edit.text():
241            value = float(str(text_edit.text()))
242            if text_edit == self.txtQxMax:
243                if value <= 0 or value > 1000:
244                    text_edit.setStyleSheet('background-color: rgb(255, 182, 193);')
245                else:
246                    text_edit.setStyleSheet('background-color: rgb(255, 255, 255);')
247            elif text_edit == self.txtNoQBins:
248                if value < 2 or value > 1000:
249                    self.txtNoQBins.setStyleSheet('background-color: rgb(255, 182, 193);')
250                else:
251                    self.txtNoQBins.setStyleSheet('background-color: rgb(255, 255, 255);')
252
253    def update_gui(self):
254        """ Update the interface with values from loaded data """
255        self.model.set_is_avg(self.is_avg)
256        self.model.set_sld_data(self.sld_data)
257        self.model.params['total_volume'] = len(self.sld_data.sld_n)*self.sld_data.vol_pix[0]
258
259        # add condition for activation of save button
260        self.cmdSave.setEnabled(True)
261
262        # activation of 3D plots' buttons (with and without arrows)
263        self.cmdDraw.setEnabled(self.sld_data is not None)
264        self.cmdDrawpoints.setEnabled(self.sld_data is not None)
265
266        self.txtScale.setText(str(self.model.params['scale']))
267        self.txtBackground.setText(str(self.model.params['background']))
268        self.txtSolventSLD.setText(str(self.model.params['solvent_SLD']))
269
270        # Volume to write to interface: npts x volume of first pixel
271        self.txtTotalVolume.setText(str(len(self.sld_data.sld_n)*self.sld_data.vol_pix[0]))
272
273        self.txtUpFracIn.setText(str(self.model.params['Up_frac_in']))
274        self.txtUpFracOut.setText(str(self.model.params['Up_frac_out']))
275        self.txtUpTheta.setText(str(self.model.params['Up_theta']))
276
277        self.txtNoPixels.setText(str(len(self.sld_data.sld_n)))
278        self.txtNoPixels.setEnabled(False)
279
280        list_parameters = ['sld_mx', 'sld_my', 'sld_mz', 'sld_n', 'xnodes',
281                           'ynodes', 'znodes', 'xstepsize', 'ystepsize',
282                           'zstepsize']
283        list_gui_button = [self.txtMx, self.txtMy, self.txtMz, self.txtNucl,
284                           self.txtXnodes, self.txtYnodes, self.txtZnodes,
285                           self.txtXstepsize, self.txtYstepsize,
286                           self.txtZstepsize]
287
288        # Fill right hand side of GUI
289        for indx, item in enumerate(list_parameters):
290            if getattr(self.sld_data, item) is None:
291                list_gui_button[indx].setText('NaN')
292            else:
293                value = getattr(self.sld_data, item)
294                if isinstance(value, numpy.ndarray):
295                    item_for_gui = str(GuiUtils.formatNumber(numpy.average(value), True))
296                else:
297                    item_for_gui = str(GuiUtils.formatNumber(value, True))
298                list_gui_button[indx].setText(item_for_gui)
299
300        # Enable / disable editing of right hand side of GUI
301        for indx, item in enumerate(list_parameters):
302            if indx < 4:
303                # this condition only applies to Mx,y,z and Nucl
304                value = getattr(self.sld_data, item)
305                enable = self.sld_data.pix_type == 'pixel' \
306                         and numpy.min(value) == numpy.max(value)
307            else:
308                enable = not self.sld_data.is_data
309            list_gui_button[indx].setEnabled(enable)
310
311    def write_new_values_from_gui(self):
312        """
313        update parameters using modified inputs from GUI
314        used before computing
315        """
316        if self.txtScale.isModified():
317            self.model.params['scale'] = float(self.txtScale.text())
318
319        if self.txtBackground.isModified():
320            self.model.params['background'] = float(self.txtBackground.text())
321
322        if self.txtSolventSLD.isModified():
323            self.model.params['solvent_SLD'] = float(self.txtSolventSLD.text())
324
325        # Different condition for total volume to get correct volume after
326        # applying set_sld_data in compute
327        if self.txtTotalVolume.isModified() \
328                or self.model.params['total_volume'] != float(self.txtTotalVolume.text()):
329            self.model.params['total_volume'] = float(self.txtTotalVolume.text())
330
331        if self.txtUpFracIn.isModified():
332            self.model.params['Up_frac_in'] = float(self.txtUpFracIn.text())
333
334        if self.txtUpFracOut.isModified():
335            self.model.params['Up_frac_out'] = float(self.txtUpFracOut.text())
336
337        if self.txtUpTheta.isModified():
338            self.model.params['Up_theta'] = float(self.txtUpTheta.text())
339
340        if self.txtMx.isModified():
341            self.sld_data.sld_mx = float(self.txtMx.text())*\
342                                   numpy.ones(len(self.sld_data.sld_mx))
343
344        if self.txtMy.isModified():
345            self.sld_data.sld_my = float(self.txtMy.text())*\
346                                   numpy.ones(len(self.sld_data.sld_my))
347
348        if self.txtMz.isModified():
349            self.sld_data.sld_mz = float(self.txtMz.text())*\
350                                   numpy.ones(len(self.sld_data.sld_mz))
351
352        if self.txtNucl.isModified():
353            self.sld_data.sld_n = float(self.txtNucl.text())*\
354                                  numpy.ones(len(self.sld_data.sld_n))
355
356        if self.txtXnodes.isModified():
357            self.sld_data.xnodes = int(self.txtXnodes.text())
358
359        if self.txtYnodes.isModified():
360            self.sld_data.ynodes = int(self.txtYnodes.text())
361
362        if self.txtZnodes.isModified():
363            self.sld_data.znodes = int(self.txtZnodes.text())
364
365        if self.txtXstepsize.isModified():
366            self.sld_data.xstepsize = float(self.txtXstepsize.text())
367
368        if self.txtYstepsize.isModified():
369            self.sld_data.ystepsize = float(self.txtYstepsize.text())
370
371        if self.txtZstepsize.isModified():
372            self.sld_data.zstepsize = float(self.txtZstepsize.text())
373
374        if self.cbOptionsCalc.isVisible():
375            self.is_avg = (self.cbOptionsCalc.currentIndex() == 1)
376
377    def onHelp(self):
378        """
379        Bring up the Generic Scattering calculator Documentation whenever
380        the HELP button is clicked.
381        Calls Documentation Window with the path of the location within the
382        documentation tree (after /doc/ ....".
383        """
384        location = "/user/qtgui/Calculators/sas_calculator_help.html"
385        self.manager.showHelp(location)
386
387    def onReset(self):
388        """ Reset the inputs of textEdit to default values """
389        try:
390            # reset values in textedits
391            self.txtUpFracIn.setText("1.0")
392            self.txtUpFracOut.setText("1.0")
393            self.txtUpTheta.setText("0.0")
394            self.txtBackground.setText("0.0")
395            self.txtScale.setText("1.0")
396            self.txtSolventSLD.setText("0.0")
397            self.txtTotalVolume.setText("216000.0")
398            self.txtNoQBins.setText("50")
399            self.txtQxMax.setText("0.3")
400            self.txtNoPixels.setText("1000")
401            self.txtMx.setText("0")
402            self.txtMy.setText("0")
403            self.txtMz.setText("0")
404            self.txtNucl.setText("6.97e-06")
405            self.txtXnodes.setText("10")
406            self.txtYnodes.setText("10")
407            self.txtZnodes.setText("10")
408            self.txtXstepsize.setText("6")
409            self.txtYstepsize.setText("6")
410            self.txtZstepsize.setText("6")
411            # reset Load button and textedit
412            self.txtData.setText('Default SLD Profile')
413            self.cmdLoad.setEnabled(True)
414            self.cmdLoad.setText('Load')
415            # reset option for calculation
416            self.cbOptionsCalc.setCurrentIndex(0)
417            self.cbOptionsCalc.setVisible(False)
418            # reset shape button
419            self.cbShape.setCurrentIndex(0)
420            self.cbShape.setEnabled(True)
421            # reset compute button
422            self.cmdCompute.setText('Compute')
423            self.cmdCompute.setEnabled(True)
424            # TODO reload default data set
425            self._create_default_sld_data()
426
427        finally:
428            pass
429
430    def _create_default_2d_data(self):
431        """
432        Copied from previous version
433        Create 2D data by default
434        :warning: This data is never plotted.
435        """
436        self.qmax_x = float(self.txtQxMax.text())
437        self.npts_x = int(self.txtNoQBins.text())
438        self.data = Data2D()
439        self.data.is_data = False
440        # # Default values
441        self.data.detector.append(Detector())
442        index = len(self.data.detector) - 1
443        self.data.detector[index].distance = 8000  # mm
444        self.data.source.wavelength = 6  # A
445        self.data.detector[index].pixel_size.x = 5  # mm
446        self.data.detector[index].pixel_size.y = 5  # mm
447        self.data.detector[index].beam_center.x = self.qmax_x
448        self.data.detector[index].beam_center.y = self.qmax_x
449        xmax = self.qmax_x
450        xmin = -xmax
451        ymax = self.qmax_x
452        ymin = -ymax
453        qstep = self.npts_x
454
455        x = numpy.linspace(start=xmin, stop=xmax, num=qstep, endpoint=True)
456        y = numpy.linspace(start=ymin, stop=ymax, num=qstep, endpoint=True)
457        # use data info instead
458        new_x = numpy.tile(x, (len(y), 1))
459        new_y = numpy.tile(y, (len(x), 1))
460        new_y = new_y.swapaxes(0, 1)
461        # all data require now in 1d array
462        qx_data = new_x.flatten()
463        qy_data = new_y.flatten()
464        q_data = numpy.sqrt(qx_data * qx_data + qy_data * qy_data)
465        # set all True (standing for unmasked) as default
466        mask = numpy.ones(len(qx_data), dtype=bool)
467        self.data.source = Source()
468        self.data.data = numpy.ones(len(mask))
469        self.data.err_data = numpy.ones(len(mask))
470        self.data.qx_data = qx_data
471        self.data.qy_data = qy_data
472        self.data.q_data = q_data
473        self.data.mask = mask
474        # store x and y bin centers in q space
475        self.data.x_bins = x
476        self.data.y_bins = y
477        # max and min taking account of the bin sizes
478        self.data.xmin = xmin
479        self.data.xmax = xmax
480        self.data.ymin = ymin
481        self.data.ymax = ymax
482
483    def _create_default_sld_data(self):
484        """
485        Copied from previous version
486        Making default sld-data
487        """
488        sld_n_default = 6.97e-06  # what is this number??
489        omfdata = sas_gen.OMFData()
490        omf2sld = sas_gen.OMF2SLD()
491        omf2sld.set_data(omfdata, self.default_shape)
492        self.sld_data = omf2sld.output
493        self.sld_data.is_data = False
494        self.sld_data.filename = "Default SLD Profile"
495        self.sld_data.set_sldn(sld_n_default)
496
497    def _create_default_1d_data(self):
498        """
499        Copied from previous version
500        Create 1D data by default
501        :warning: This data is never plotted.
502                    residuals.x = data_copy.x[index]
503            residuals.dy = numpy.ones(len(residuals.y))
504            residuals.dx = None
505            residuals.dxl = None
506            residuals.dxw = None
507        """
508        self.qmax_x = float(self.txtQxMax.text())
509        self.npts_x = int(self.txtNoQBins.text())
510        #  Default values
511        xmax = self.qmax_x
512        xmin = self.qmax_x * _Q1D_MIN
513        qstep = self.npts_x
514        x = numpy.linspace(start=xmin, stop=xmax, num=qstep, endpoint=True)
515        # store x and y bin centers in q space
516        y = numpy.ones(len(x))
517        dy = numpy.zeros(len(x))
518        dx = numpy.zeros(len(x))
519        self.data = Data1D(x=x, y=y)
520        self.data.dx = dx
521        self.data.dy = dy
522
523    def onCompute(self):
524        """
525        Copied from previous version
526        Execute the computation of I(qx, qy)
527        """
528        # Set default data when nothing loaded yet
529        if self.sld_data is None:
530            self._create_default_sld_data()
531        try:
532            self.model.set_sld_data(self.sld_data)
533            self.write_new_values_from_gui()
534            if self.is_avg or self.is_avg is None:
535                self._create_default_1d_data()
536                i_out = numpy.zeros(len(self.data.y))
537                inputs = [self.data.x, [], i_out]
538            else:
539                self._create_default_2d_data()
540                i_out = numpy.zeros(len(self.data.data))
541                inputs = [self.data.qx_data, self.data.qy_data, i_out]
542            logging.info("Computation is in progress...")
543            self.cmdCompute.setText('Wait...')
544            self.cmdCompute.setEnabled(False)
545            d = threads.deferToThread(self.complete, inputs, self._update)
546            # Add deferred callback for call return
547            d.addCallback(self.plot_1_2d)
548            d.addErrback(self.calculateFailed)
549        except:
550            log_msg = "{}. stop".format(sys.exc_info()[1])
551            logging.info(log_msg)
552        return
553
554    def _update(self, value):
555        """
556        Copied from previous version
557        """
558        pass
559
560    def calculateFailed(self, reason):
561        """
562        """
563        print("Calculate Failed with:\n", reason)
564        pass
565
566    def complete(self, input, update=None):
567        """
568        Gen compute complete function
569        :Param input: input list [qx_data, qy_data, i_out]
570        """
571        out = numpy.empty(0)
572        for ind in range(len(input[0])):
573            if self.is_avg:
574                if ind % 1 == 0 and update is not None:
575                    # update()
576                    percentage = int(100.0 * float(ind) / len(input[0]))
577                    update(percentage)
578                    time.sleep(0.001)  # 0.1
579                inputi = [input[0][ind:ind + 1], [], input[2][ind:ind + 1]]
580                outi = self.model.run(inputi)
581                out = numpy.append(out, outi)
582            else:
583                if ind % 50 == 0 and update is not None:
584                    percentage = int(100.0 * float(ind) / len(input[0]))
585                    update(percentage)
586                    time.sleep(0.001)
587                inputi = [input[0][ind:ind + 1], input[1][ind:ind + 1],
588                          input[2][ind:ind + 1]]
589                outi = self.model.runXY(inputi)
590                out = numpy.append(out, outi)
591        self.data_to_plot = out
592        logging.info('Gen computation completed.')
593        self.cmdCompute.setText('Compute')
594        self.cmdCompute.setEnabled(True)
595        return
596
597    def onSaveFile(self):
598        """Save data as .sld file"""
599        path = os.path.dirname(str(self.datafile))
600        default_name = os.path.join(path, 'sld_file')
601        kwargs = {
602            'parent': self,
603            'directory': default_name,
604            'filter': 'SLD file (*.sld)',
605            'options': QtWidgets.QFileDialog.DontUseNativeDialog}
606        # Query user for filename.
607        filename_tuple = QtWidgets.QFileDialog.getSaveFileName(**kwargs)
608        filename = filename_tuple[0]
609        if filename:
610            try:
611                _, extension = os.path.splitext(filename)
612                if not extension:
613                    filename = '.'.join((filename, 'sld'))
614                sas_gen.SLDReader().write(filename, self.sld_data)
615            except:
616                raise
617
618    def plot3d(self, has_arrow=False):
619        """ Generate 3D plot in real space with or without arrows """
620        self.write_new_values_from_gui()
621        graph_title = " Graph {}: {} 3D SLD Profile".format(self.graph_num,
622                                                            self.file_name)
623        if has_arrow:
624            graph_title += ' - Magnetic Vector as Arrow'
625
626        plot3D = Plotter3D(self, graph_title)
627        plot3D.plot(self.sld_data, has_arrow=has_arrow)
628        plot3D.show()
629        self.graph_num += 1
630
631    def plot_1_2d(self, d):
632        """ Generate 1D or 2D plot, called in Compute"""
633        if self.is_avg or self.is_avg is None:
634            data = Data1D(x=self.data.x, y=self.data_to_plot)
635            data.title = "GenSAS {}  #{} 1D".format(self.file_name,
636                                                    int(self.graph_num))
637            data.xaxis('\\rm{Q_{x}}', '\AA^{-1}')
638            data.yaxis('\\rm{Intensity}', 'cm^{-1}')
639            plot1D = Plotter(self)
640            plot1D.plot(data)
641            plot1D.show()
642            self.graph_num += 1
643            # TODO
644            print('TRANSFER OF DATA TO MAIN PANEL TO BE IMPLEMENTED')
645            return plot1D
646        else:
647            numpy.nan_to_num(self.data_to_plot)
648            data = Data2D(image=self.data_to_plot,
649                          qx_data=self.data.qx_data,
650                          qy_data=self.data.qy_data,
651                          q_data=self.data.q_data,
652                          xmin=self.data.xmin, xmax=self.data.ymax,
653                          ymin=self.data.ymin, ymax=self.data.ymax,
654                          err_image=self.data.err_data)
655            data.title = "GenSAS {}  #{} 2D".format(self.file_name,
656                                                    int(self.graph_num))
657            plot2D = Plotter2D(self)
658            plot2D.plot(data)
659            plot2D.show()
660            self.graph_num += 1
661            # TODO
662            print('TRANSFER OF DATA TO MAIN PANEL TO BE IMPLEMENTED')
663            return plot2D
664
665
666class Plotter3DWidget(PlotterBase):
667    """
668    3D Plot widget for use with a QDialog
669    """
670    def __init__(self, parent=None, manager=None):
671        super(Plotter3DWidget, self).__init__(parent,  manager=manager)
672
673    @property
674    def data(self):
675        return self._data
676
677    @data.setter
678    def data(self, data=None):
679        """ data setter """
680        self._data = data
681
682    def plot(self, data=None, has_arrow=False):
683        """
684        Plot 3D self._data
685        """
686        if not data:
687            return
688        self.data = data
689        #assert(self._data)
690        # Prepare and show the plot
691        self.showPlot(data=self.data, has_arrow=has_arrow)
692
693    def showPlot(self, data, has_arrow=False):
694        """
695        Render and show the current data
696        """
697        # If we don't have any data, skip.
698        if data is None:
699            return
700        # This import takes forever - place it here so the main UI starts faster
701        from mpl_toolkits.mplot3d import Axes3D
702        color_dic = {'H': 'blue', 'D': 'purple', 'N': 'orange',
703                     'O': 'red', 'C': 'green', 'P': 'cyan', 'Other': 'k'}
704        marker = ','
705        m_size = 2
706
707        pos_x = data.pos_x
708        pos_y = data.pos_y
709        pos_z = data.pos_z
710        sld_mx = data.sld_mx
711        sld_my = data.sld_my
712        sld_mz = data.sld_mz
713        pix_symbol = data.pix_symbol
714        sld_tot = numpy.fabs(sld_mx) + numpy.fabs(sld_my) + \
715                  numpy.fabs(sld_mz) + numpy.fabs(data.sld_n)
716        is_nonzero = sld_tot > 0.0
717        is_zero = sld_tot == 0.0
718
719        if data.pix_type == 'atom':
720            marker = 'o'
721            m_size = 3.5
722
723        self.figure.clear()
724        self.figure.subplots_adjust(left=0.1, right=.8, bottom=.1)
725        ax = Axes3D(self.figure)
726        ax.set_xlabel('x ($\A{}$)'.format(data.pos_unit))
727        ax.set_ylabel('z ($\A{}$)'.format(data.pos_unit))
728        ax.set_zlabel('y ($\A{}$)'.format(data.pos_unit))
729
730        # I. Plot null points
731        if is_zero.any():
732            im = ax.plot(pos_x[is_zero], pos_z[is_zero], pos_y[is_zero],
733                           marker, c="y", alpha=0.5, markeredgecolor='y',
734                           markersize=m_size)
735            pos_x = pos_x[is_nonzero]
736            pos_y = pos_y[is_nonzero]
737            pos_z = pos_z[is_nonzero]
738            sld_mx = sld_mx[is_nonzero]
739            sld_my = sld_my[is_nonzero]
740            sld_mz = sld_mz[is_nonzero]
741            pix_symbol = data.pix_symbol[is_nonzero]
742        # II. Plot selective points in color
743        other_color = numpy.ones(len(pix_symbol), dtype='bool')
744        for key in list(color_dic.keys()):
745            chosen_color = pix_symbol == key
746            if numpy.any(chosen_color):
747                other_color = other_color & (chosen_color!=True)
748                color = color_dic[key]
749                im = ax.plot(pos_x[chosen_color], pos_z[chosen_color],
750                         pos_y[chosen_color], marker, c=color, alpha=0.5,
751                         markeredgecolor=color, markersize=m_size, label=key)
752        # III. Plot All others
753        if numpy.any(other_color):
754            a_name = ''
755            if data.pix_type == 'atom':
756                # Get atom names not in the list
757                a_names = [symb for symb in pix_symbol \
758                           if symb not in list(color_dic.keys())]
759                a_name = a_names[0]
760                for name in a_names:
761                    new_name = ", " + name
762                    if a_name.count(name) == 0:
763                        a_name += new_name
764            # plot in black
765            im = ax.plot(pos_x[other_color], pos_z[other_color],
766                         pos_y[other_color], marker, c="k", alpha=0.5,
767                         markeredgecolor="k", markersize=m_size, label=a_name)
768        if data.pix_type == 'atom':
769            ax.legend(loc='upper left', prop={'size': 10})
770        # IV. Draws atomic bond with grey lines if any
771        if data.has_conect:
772            for ind in range(len(data.line_x)):
773                im = ax.plot(data.line_x[ind], data.line_z[ind],
774                             data.line_y[ind], '-', lw=0.6, c="grey",
775                             alpha=0.3)
776        # V. Draws magnetic vectors
777        if has_arrow and len(pos_x) > 0:
778            def _draw_arrow(input=None, update=None):
779                # import moved here for performance reasons
780                from sas.qtgui.Plotting.Arrow3D import Arrow3D
781                """
782                draw magnetic vectors w/arrow
783                """
784                max_mx = max(numpy.fabs(sld_mx))
785                max_my = max(numpy.fabs(sld_my))
786                max_mz = max(numpy.fabs(sld_mz))
787                max_m = max(max_mx, max_my, max_mz)
788                try:
789                    max_step = max(data.xstepsize, data.ystepsize, data.zstepsize)
790                except:
791                    max_step = 0
792                if max_step <= 0:
793                    max_step = 5
794                try:
795                    if max_m != 0:
796                        unit_x2 = sld_mx / max_m
797                        unit_y2 = sld_my / max_m
798                        unit_z2 = sld_mz / max_m
799                        # 0.8 is for avoiding the color becomes white=(1,1,1))
800                        color_x = numpy.fabs(unit_x2 * 0.8)
801                        color_y = numpy.fabs(unit_y2 * 0.8)
802                        color_z = numpy.fabs(unit_z2 * 0.8)
803                        x2 = pos_x + unit_x2 * max_step
804                        y2 = pos_y + unit_y2 * max_step
805                        z2 = pos_z + unit_z2 * max_step
806                        x_arrow = numpy.column_stack((pos_x, x2))
807                        y_arrow = numpy.column_stack((pos_y, y2))
808                        z_arrow = numpy.column_stack((pos_z, z2))
809                        colors = numpy.column_stack((color_x, color_y, color_z))
810                        arrows = Arrow3D(self.figure, x_arrow, z_arrow, y_arrow,
811                                        colors, mutation_scale=10, lw=1,
812                                        arrowstyle="->", alpha=0.5)
813                        ax.add_artist(arrows)
814                except:
815                    pass
816                log_msg = "Arrow Drawing completed.\n"
817                logging.info(log_msg)
818            log_msg = "Arrow Drawing is in progress..."
819            logging.info(log_msg)
820
821            # Defer the drawing of arrows to another thread
822            d = threads.deferToThread(_draw_arrow, ax)
823
824        self.figure.canvas.resizing = False
825        self.figure.canvas.draw()
826
827
828class Plotter3D(QtWidgets.QDialog, Plotter3DWidget):
829    def __init__(self, parent=None, graph_title=''):
830        self.graph_title = graph_title
831        QtWidgets.QDialog.__init__(self)
832        Plotter3DWidget.__init__(self, manager=parent)
833        self.setWindowTitle(self.graph_title)
834
Note: See TracBrowser for help on using the repository browser.