source: sasview/src/sas/qtgui/Calculators/GenericScatteringCalculator.py @ dc5ef15

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalc
Last change on this file since dc5ef15 was dc5ef15, checked in by Piotr Rozyczko <rozyczko@…>, 7 years ago

Removed qtgui dependency on sasgui and wx SASVIEW-590

  • Property mode set to 100755
File size: 32.7 KB
Line 
1import sys
2import os
3import numpy
4import logging
5import time
6
7from PyQt4 import QtGui
8from PyQt4 import QtCore
9from twisted.internet import threads
10from mpl_toolkits.mplot3d import Axes3D
11
12import sas.qtgui.Utilities.GuiUtils as GuiUtils
13
14from sas.qtgui.Utilities.GenericReader import GenReader
15
16from sas.sascalc.dataloader.data_info import Detector
17from sas.sascalc.dataloader.data_info import Source
18from sas.sascalc.calculator import sas_gen
19
20from sas.qtgui.Plotting.Arrow3D import Arrow3D
21from sas.qtgui.Plotting.PlotterBase import PlotterBase
22from sas.qtgui.Plotting.Plotter2D import Plotter2D
23from sas.qtgui.Plotting.Plotter import Plotter
24from sas.qtgui.Plotting.PlotterData import Data1D
25from sas.qtgui.Plotting.PlotterData import Data2D
26
27# Local UI
28from UI.GenericScatteringCalculator import Ui_GenericScatteringCalculator
29
30_Q1D_MIN = 0.001
31
32
33class GenericScatteringCalculator(QtGui.QDialog, Ui_GenericScatteringCalculator):
34
35    trigger_plot_3d = QtCore.pyqtSignal()
36
37    def __init__(self, parent=None):
38        super(GenericScatteringCalculator, self).__init__()
39        self.setupUi(self)
40        self.manager = parent
41        self.communicator = self.manager.communicator()
42        self.model = sas_gen.GenSAS()
43        self.omf_reader = sas_gen.OMFReader()
44        self.sld_reader = sas_gen.SLDReader()
45        self.pdb_reader = sas_gen.PDBReader()
46        self.reader = None
47        self.sld_data = None
48
49        self.parameters = []
50        self.data = None
51        self.datafile = None
52        self.file_name = ''
53        self.ext = None
54        self.default_shape = str(self.cbShape.currentText())
55        self.is_avg = False
56        self.data_to_plot = None
57
58        # combox box
59        self.cbOptionsCalc.setVisible(False)
60
61        # push buttons
62        self.cmdClose.clicked.connect(self.accept)
63        self.cmdHelp.clicked.connect(self.onHelp)
64
65        self.cmdLoad.clicked.connect(self.loadFile)
66        self.cmdCompute.clicked.connect(self.onCompute)
67        self.cmdReset.clicked.connect(self.onReset)
68        self.cmdSave.clicked.connect(self.onSaveFile)
69
70        self.cmdDraw.clicked.connect(lambda: self.plot3d(has_arrow=True))
71        self.cmdDrawpoints.clicked.connect(lambda: self.plot3d(has_arrow=False))
72
73        # validators
74        # scale, volume and background must be positive
75        validat_regex_pos = QtCore.QRegExp('^[+]?([.]\d+|\d+([.]\d+)?)$')
76        self.txtScale.setValidator(QtGui.QRegExpValidator(validat_regex_pos,
77                                                          self.txtScale))
78        self.txtBackground.setValidator(QtGui.QRegExpValidator(
79            validat_regex_pos, self.txtBackground))
80        self.txtTotalVolume.setValidator(QtGui.QRegExpValidator(
81            validat_regex_pos, self.txtTotalVolume))
82
83        # fraction of spin up between 0 and 1
84        validat_regexbetween0_1 = QtCore.QRegExp('^(0(\.\d*)*|1(\.0+)?)$')
85        self.txtUpFracIn.setValidator(
86            QtGui.QRegExpValidator(validat_regexbetween0_1, self.txtUpFracIn))
87        self.txtUpFracOut.setValidator(
88            QtGui.QRegExpValidator(validat_regexbetween0_1, self.txtUpFracOut))
89
90        # 0 < Qmax <= 1000
91        validat_regex_q = QtCore.QRegExp('^1000$|^[+]?(\d{1,3}([.]\d+)?)$')
92        self.txtQxMax.setValidator(QtGui.QRegExpValidator(validat_regex_q,
93                                                          self.txtQxMax))
94        self.txtQxMax.textChanged.connect(self.check_value)
95
96        # 2 <= Qbin <= 1000
97        self.txtNoQBins.setValidator(QtGui.QRegExpValidator(validat_regex_q,
98                                                            self.txtNoQBins))
99        self.txtNoQBins.textChanged.connect(self.check_value)
100
101        # plots - 3D in real space
102        self.trigger_plot_3d.connect(lambda: self.plot3d(has_arrow=False))
103
104        self.graph_num = 1  # index for name of graph
105
106        # TODO the option Ellipsoid has not been implemented
107        self.cbShape.currentIndexChanged.connect(self.selectedshapechange)
108
109    def selectedshapechange(self):
110        """
111        TODO Temporary solution to display information about option 'Ellipsoid'
112        """
113        print "The option Ellipsoid has not been implemented yet."
114        self.communicator.statusBarUpdateSignal.emit(
115            "The option Ellipsoid has not been implemented yet.")
116
117    def loadFile(self):
118        """
119        Open menu to choose the datafile to load
120        Only extensions .SLD, .PDB, .OMF, .sld, .pdb, .omf
121        """
122        try:
123            self.datafile = QtGui.QFileDialog.getOpenFileName(
124                self, "Choose a file", "", "All Gen files (*.OMF *.omf) ;;"
125                                          "SLD files (*.SLD *.sld);;PDB files (*.pdb *.PDB);; "
126                                          "OMF files (*.OMF *.omf);; "
127                                          "All files (*.*)")
128            if self.datafile:
129                self.default_shape = str(self.cbShape.currentText())
130                self.file_name = os.path.basename(str(self.datafile))
131                self.ext = os.path.splitext(str(self.datafile))[1]
132                if self.ext in self.omf_reader.ext:
133                    loader = self.omf_reader
134                elif self.ext in self.sld_reader.ext:
135                    loader = self.sld_reader
136                elif self.ext in self.pdb_reader.ext:
137                    loader = self.pdb_reader
138                else:
139                    loader = None
140                # disable some entries depending on type of loaded file
141                # (according to documentation)
142                if self.ext.lower() in ['.sld', '.omf', '.pdb']:
143                    self.txtUpFracIn.setEnabled(False)
144                    self.txtUpFracOut.setEnabled(False)
145                    self.txtUpTheta.setEnabled(False)
146
147                if self.reader is not None and self.reader.isrunning():
148                    self.reader.stop()
149                self.cmdLoad.setEnabled(False)
150                self.cmdLoad.setText('Loading...')
151                self.communicator.statusBarUpdateSignal.emit(
152                    "Loading File {}".format(os.path.basename(
153                        str(self.datafile))))
154                self.reader = GenReader(path=str(self.datafile), loader=loader,
155                                        completefn=self.complete_loading,
156                                        updatefn=self.load_update)
157                self.reader.queue()
158        except (RuntimeError, IOError):
159            log_msg = "Generic SAS Calculator: %s" % sys.exc_value
160            logging.info(log_msg)
161            raise
162        return
163
164    def load_update(self):
165        """ Legacy function used in GenRead """
166        if self.reader.isrunning():
167            status_type = "progress"
168        else:
169            status_type = "stop"
170        logging.info(status_type)
171
172    def complete_loading(self, data=None):
173        """ Function used in GenRead"""
174        self.cbShape.setEnabled(False)
175        try:
176            is_pdbdata = False
177            self.txtData.setText(os.path.basename(str(self.datafile)))
178            self.is_avg = False
179            if self.ext in self.omf_reader.ext:
180                gen = sas_gen.OMF2SLD()
181                gen.set_data(data)
182                self.sld_data = gen.get_magsld()
183                self.check_units()
184            elif self.ext in self.sld_reader.ext:
185                self.sld_data = data
186            elif self.ext in self.pdb_reader.ext:
187                self.sld_data = data
188                is_pdbdata = True
189            # Display combobox of orientation only for pdb data
190            self.cbOptionsCalc.setVisible(is_pdbdata)
191            self.update_gui()
192        except IOError:
193            log_msg = "Loading Error: " \
194                      "This file format is not supported for GenSAS."
195            logging.info(log_msg)
196            raise
197        except ValueError:
198            log_msg = "Could not find any data"
199            logging.info(log_msg)
200            raise
201        logging.info("Load Complete")
202        self.cmdLoad.setEnabled(True)
203        self.cmdLoad.setText('Load')
204        self.trigger_plot_3d.emit()
205
206    def check_units(self):
207        """
208        Check if the units from the OMF file correspond to the default ones
209        displayed on the interface.
210        If not, modify the GUI with the correct unit
211        """
212        #  TODO: adopt the convention of font and symbol for the updated values
213        if sas_gen.OMFData().valueunit != 'A^(-2)':
214            value_unit = sas_gen.OMFData().valueunit
215            self.lbl_unitMx.setText(value_unit)
216            self.lbl_unitMy.setText(value_unit)
217            self.lbl_unitMz.setText(value_unit)
218            self.lbl_unitNucl.setText(value_unit)
219        if sas_gen.OMFData().meshunit != 'A':
220            mesh_unit = sas_gen.OMFData().meshunit
221            self.lbl_unitx.setText(mesh_unit)
222            self.lbl_unity.setText(mesh_unit)
223            self.lbl_unitz.setText(mesh_unit)
224            self.lbl_unitVolume.setText(mesh_unit+"^3")
225
226    def check_value(self):
227        """Check range of text edits for QMax and Number of Qbins """
228        text_edit = self.sender()
229        text_edit.setStyleSheet(
230            QtCore.QString.fromUtf8('background-color: rgb(255, 255, 255);'))
231        if text_edit.text():
232            value = float(str(text_edit.text()))
233            if text_edit == self.txtQxMax:
234                if value <= 0 or value > 1000:
235                    text_edit.setStyleSheet(QtCore.QString.fromUtf8(
236                        'background-color: rgb(255, 182, 193);'))
237                else:
238                    text_edit.setStyleSheet(QtCore.QString.fromUtf8(
239                        'background-color: rgb(255, 255, 255);'))
240            elif text_edit == self.txtNoQBins:
241                if value < 2 or value > 1000:
242                    self.txtNoQBins.setStyleSheet(QtCore.QString.fromUtf8(
243                        'background-color: rgb(255, 182, 193);'))
244                else:
245                    self.txtNoQBins.setStyleSheet(QtCore.QString.fromUtf8(
246                        'background-color: rgb(255, 255, 255);'))
247
248    def update_gui(self):
249        """ Update the interface with values from loaded data """
250        self.model.set_is_avg(self.is_avg)
251        self.model.set_sld_data(self.sld_data)
252        self.model.params['total_volume'] = len(self.sld_data.sld_n)*self.sld_data.vol_pix[0]
253
254        # add condition for activation of save button
255        self.cmdSave.setEnabled(True)
256
257        # activation of 3D plots' buttons (with and without arrows)
258        self.cmdDraw.setEnabled(self.sld_data is not None)
259        self.cmdDrawpoints.setEnabled(self.sld_data is not None)
260
261        self.txtScale.setText(str(self.model.params['scale']))
262        self.txtBackground.setText(str(self.model.params['background']))
263        self.txtSolventSLD.setText(str(self.model.params['solvent_SLD']))
264
265        # Volume to write to interface: npts x volume of first pixel
266        self.txtTotalVolume.setText(str(len(self.sld_data.sld_n)*self.sld_data.vol_pix[0]))
267
268        self.txtUpFracIn.setText(str(self.model.params['Up_frac_in']))
269        self.txtUpFracOut.setText(str(self.model.params['Up_frac_out']))
270        self.txtUpTheta.setText(str(self.model.params['Up_theta']))
271
272        self.txtNoPixels.setText(str(len(self.sld_data.sld_n)))
273        self.txtNoPixels.setEnabled(False)
274
275        list_parameters = ['sld_mx', 'sld_my', 'sld_mz', 'sld_n', 'xnodes',
276                           'ynodes', 'znodes', 'xstepsize', 'ystepsize',
277                           'zstepsize']
278        list_gui_button = [self.txtMx, self.txtMy, self.txtMz, self.txtNucl,
279                           self.txtXnodes, self.txtYnodes, self.txtZnodes,
280                           self.txtXstepsize, self.txtYstepsize,
281                           self.txtZstepsize]
282
283        # Fill right hand side of GUI
284        for indx, item in enumerate(list_parameters):
285            if getattr(self.sld_data, item) is None:
286                list_gui_button[indx].setText('NaN')
287            else:
288                value = getattr(self.sld_data, item)
289                if isinstance(value, numpy.ndarray):
290                    item_for_gui = str(GuiUtils.formatNumber(numpy.average(value), True))
291                else:
292                    item_for_gui = str(GuiUtils.formatNumber(value, True))
293                list_gui_button[indx].setText(item_for_gui)
294
295        # Enable / disable editing of right hand side of GUI
296        for indx, item in enumerate(list_parameters):
297            if indx < 4:
298                # this condition only applies to Mx,y,z and Nucl
299                value = getattr(self.sld_data, item)
300                enable = self.sld_data.pix_type == 'pixel' \
301                         and numpy.min(value) == numpy.max(value)
302            else:
303                enable = not self.sld_data.is_data
304            list_gui_button[indx].setEnabled(enable)
305
306    def write_new_values_from_gui(self):
307        """
308        update parameters using modified inputs from GUI
309        used before computing
310        """
311        if self.txtScale.isModified():
312            self.model.params['scale'] = float(self.txtScale.text())
313
314        if self.txtBackground.isModified():
315            self.model.params['background'] = float(self.txtBackground.text())
316
317        if self.txtSolventSLD.isModified():
318            self.model.params['solvent_SLD'] = float(self.txtSolventSLD.text())
319
320        # Different condition for total volume to get correct volume after
321        # applying set_sld_data in compute
322        if self.txtTotalVolume.isModified() \
323                or self.model.params['total_volume'] != float(self.txtTotalVolume.text()):
324            self.model.params['total_volume'] = float(self.txtTotalVolume.text())
325
326        if self.txtUpFracIn.isModified():
327            self.model.params['Up_frac_in'] = float(self.txtUpFracIn.text())
328
329        if self.txtUpFracOut.isModified():
330            self.model.params['Up_frac_out'] = float(self.txtUpFracOut.text())
331
332        if self.txtUpTheta.isModified():
333            self.model.params['Up_theta'] = float(self.txtUpTheta.text())
334
335        if self.txtMx.isModified():
336            self.sld_data.sld_mx = float(self.txtMx.text())*\
337                                   numpy.ones(len(self.sld_data.sld_mx))
338
339        if self.txtMy.isModified():
340            self.sld_data.sld_my = float(self.txtMy.text())*\
341                                   numpy.ones(len(self.sld_data.sld_my))
342
343        if self.txtMz.isModified():
344            self.sld_data.sld_mz = float(self.txtMz.text())*\
345                                   numpy.ones(len(self.sld_data.sld_mz))
346
347        if self.txtNucl.isModified():
348            self.sld_data.sld_n = float(self.txtNucl.text())*\
349                                  numpy.ones(len(self.sld_data.sld_n))
350
351        if self.txtXnodes.isModified():
352            self.sld_data.xnodes = int(self.txtXnodes.text())
353
354        if self.txtYnodes.isModified():
355            self.sld_data.ynodes = int(self.txtYnodes.text())
356
357        if self.txtZnodes.isModified():
358            self.sld_data.znodes = int(self.txtZnodes.text())
359
360        if self.txtXstepsize.isModified():
361            self.sld_data.xstepsize = float(self.txtXstepsize.text())
362
363        if self.txtYstepsize.isModified():
364            self.sld_data.ystepsize = float(self.txtYstepsize.text())
365
366        if self.txtZstepsize.isModified():
367            self.sld_data.zstepsize = float(self.txtZstepsize.text())
368
369        if self.cbOptionsCalc.isVisible():
370            self.is_avg = (self.cbOptionsCalc.currentIndex() == 1)
371
372    def onHelp(self):
373        """
374        Bring up the Generic Scattering calculator Documentation whenever
375        the HELP button is clicked.
376        Calls Documentation Window with the path of the location within the
377        documentation tree (after /doc/ ....".
378        """
379        try:
380            location = self.manager.HELP_DIRECTORY_LOCATION + \
381                       "/user/sasgui/perspectives/calculator/sas_calculator_help.html"
382            self.manager._helpView.load(QtCore.QUrl(location))
383            self.manager._helpView.show()
384        except AttributeError:
385            # No manager defined - testing and standalone runs
386            pass
387
388    def onReset(self):
389        """ Reset the inputs of textEdit to default values """
390        try:
391            # reset values in textedits
392            self.txtUpFracIn.setText("1.0")
393            self.txtUpFracOut.setText("1.0")
394            self.txtUpTheta.setText("0.0")
395            self.txtBackground.setText("0.0")
396            self.txtScale.setText("1.0")
397            self.txtSolventSLD.setText("0.0")
398            self.txtTotalVolume.setText("216000.0")
399            self.txtNoQBins.setText("50")
400            self.txtQxMax.setText("0.3")
401            self.txtNoPixels.setText("1000")
402            self.txtMx.setText("0")
403            self.txtMy.setText("0")
404            self.txtMz.setText("0")
405            self.txtNucl.setText("6.97e-06")
406            self.txtXnodes.setText("10")
407            self.txtYnodes.setText("10")
408            self.txtZnodes.setText("10")
409            self.txtXstepsize.setText("6")
410            self.txtYstepsize.setText("6")
411            self.txtZstepsize.setText("6")
412            # reset Load button and textedit
413            self.txtData.setText('Default SLD Profile')
414            self.cmdLoad.setEnabled(True)
415            self.cmdLoad.setText('Load')
416            # reset option for calculation
417            self.cbOptionsCalc.setCurrentIndex(0)
418            self.cbOptionsCalc.setVisible(False)
419            # reset shape button
420            self.cbShape.setCurrentIndex(0)
421            self.cbShape.setEnabled(True)
422            # reset compute button
423            self.cmdCompute.setText('Compute')
424            self.cmdCompute.setEnabled(True)
425            # TODO reload default data set
426            self._create_default_sld_data()
427
428        finally:
429            pass
430
431    def _create_default_2d_data(self):
432        """
433        Copied from previous version
434        Create 2D data by default
435        :warning: This data is never plotted.
436        """
437        self.qmax_x = float(self.txtQxMax.text())
438        self.npts_x = int(self.txtNoQBins.text())
439        self.data = Data2D()
440        self.data.is_data = False
441        # # Default values
442        self.data.detector.append(Detector())
443        index = len(self.data.detector) - 1
444        self.data.detector[index].distance = 8000  # mm
445        self.data.source.wavelength = 6  # A
446        self.data.detector[index].pixel_size.x = 5  # mm
447        self.data.detector[index].pixel_size.y = 5  # mm
448        self.data.detector[index].beam_center.x = self.qmax_x
449        self.data.detector[index].beam_center.y = self.qmax_x
450        xmax = self.qmax_x
451        xmin = -xmax
452        ymax = self.qmax_x
453        ymin = -ymax
454        qstep = self.npts_x
455
456        x = numpy.linspace(start=xmin, stop=xmax, num=qstep, endpoint=True)
457        y = numpy.linspace(start=ymin, stop=ymax, num=qstep, endpoint=True)
458        # use data info instead
459        new_x = numpy.tile(x, (len(y), 1))
460        new_y = numpy.tile(y, (len(x), 1))
461        new_y = new_y.swapaxes(0, 1)
462        # all data require now in 1d array
463        qx_data = new_x.flatten()
464        qy_data = new_y.flatten()
465        q_data = numpy.sqrt(qx_data * qx_data + qy_data * qy_data)
466        # set all True (standing for unmasked) as default
467        mask = numpy.ones(len(qx_data), dtype=bool)
468        self.data.source = Source()
469        self.data.data = numpy.ones(len(mask))
470        self.data.err_data = numpy.ones(len(mask))
471        self.data.qx_data = qx_data
472        self.data.qy_data = qy_data
473        self.data.q_data = q_data
474        self.data.mask = mask
475        # store x and y bin centers in q space
476        self.data.x_bins = x
477        self.data.y_bins = y
478        # max and min taking account of the bin sizes
479        self.data.xmin = xmin
480        self.data.xmax = xmax
481        self.data.ymin = ymin
482        self.data.ymax = ymax
483
484    def _create_default_sld_data(self):
485        """
486        Copied from previous version
487        Making default sld-data
488        """
489        sld_n_default = 6.97e-06  # what is this number??
490        omfdata = sas_gen.OMFData()
491        omf2sld = sas_gen.OMF2SLD()
492        omf2sld.set_data(omfdata, self.default_shape)
493        self.sld_data = omf2sld.output
494        self.sld_data.is_data = False
495        self.sld_data.filename = "Default SLD Profile"
496        self.sld_data.set_sldn(sld_n_default)
497
498    def _create_default_1d_data(self):
499        """
500        Copied from previous version
501        Create 1D data by default
502        :warning: This data is never plotted.
503                    residuals.x = data_copy.x[index]
504            residuals.dy = numpy.ones(len(residuals.y))
505            residuals.dx = None
506            residuals.dxl = None
507            residuals.dxw = None
508        """
509        self.qmax_x = float(self.txtQxMax.text())
510        self.npts_x = int(self.txtNoQBins.text())
511        #  Default values
512        xmax = self.qmax_x
513        xmin = self.qmax_x * _Q1D_MIN
514        qstep = self.npts_x
515        x = numpy.linspace(start=xmin, stop=xmax, num=qstep, endpoint=True)
516        # store x and y bin centers in q space
517        y = numpy.ones(len(x))
518        dy = numpy.zeros(len(x))
519        dx = numpy.zeros(len(x))
520        self.data = Data1D(x=x, y=y)
521        self.data.dx = dx
522        self.data.dy = dy
523
524    def onCompute(self):
525        """
526        Copied from previous version
527        Execute the computation of I(qx, qy)
528        """
529        # Set default data when nothing loaded yet
530        if self.sld_data is None:
531            self._create_default_sld_data()
532        try:
533            self.model.set_sld_data(self.sld_data)
534            self.write_new_values_from_gui()
535            if self.is_avg or self.is_avg is None:
536                self._create_default_1d_data()
537                i_out = numpy.zeros(len(self.data.y))
538                inputs = [self.data.x, [], i_out]
539            else:
540                self._create_default_2d_data()
541                i_out = numpy.zeros(len(self.data.data))
542                inputs = [self.data.qx_data, self.data.qy_data, i_out]
543            logging.info("Computation is in progress...")
544            self.cmdCompute.setText('Wait...')
545            self.cmdCompute.setEnabled(False)
546            d = threads.deferToThread(self.complete, inputs, self._update)
547            # Add deferred callback for call return
548            d.addCallback(self.plot_1_2d)
549        except:
550            log_msg = "{}. stop".format(sys.exc_value)
551            logging.info(log_msg)
552        return
553
554    def _update(self, value):
555        """
556        Copied from previous version
557        """
558        pass
559
560    def complete(self, input, update=None):
561        """
562        Gen compute complete function
563        :Param input: input list [qx_data, qy_data, i_out]
564        """
565        out = numpy.empty(0)
566        for ind in range(len(input[0])):
567            if self.is_avg:
568                if ind % 1 == 0 and update is not None:
569                    # update()
570                    percentage = int(100.0 * float(ind) / len(input[0]))
571                    update(percentage)
572                    time.sleep(0.001)  # 0.1
573                inputi = [input[0][ind:ind + 1], [], input[2][ind:ind + 1]]
574                outi = self.model.run(inputi)
575                out = numpy.append(out, outi)
576            else:
577                if ind % 50 == 0 and update is not None:
578                    percentage = int(100.0 * float(ind) / len(input[0]))
579                    update(percentage)
580                    time.sleep(0.001)
581                inputi = [input[0][ind:ind + 1], input[1][ind:ind + 1],
582                          input[2][ind:ind + 1]]
583                outi = self.model.runXY(inputi)
584                out = numpy.append(out, outi)
585        self.data_to_plot = out
586        logging.info('Gen computation completed.')
587        self.cmdCompute.setText('Compute')
588        self.cmdCompute.setEnabled(True)
589        return
590
591    def onSaveFile(self):
592        """Save data as .sld file"""
593        path = os.path.dirname(str(self.datafile))
594        default_name = os.path.join(path, 'sld_file')
595        kwargs = {
596            'parent': self,
597            'directory': default_name,
598            'filter': 'SLD file (*.sld)',
599            'options': QtGui.QFileDialog.DontUseNativeDialog}
600        # Query user for filename.
601        filename = str(QtGui.QFileDialog.getSaveFileName(**kwargs))
602        if filename:
603            try:
604                if os.path.splitext(filename)[1].lower() == '.sld':
605                    sas_gen.SLDReader().write(filename, self.sld_data)
606                else:
607                    sas_gen.SLDReader().write('.'.join((filename, 'sld')),
608                                              self.sld_data)
609            except:
610                raise
611
612    def plot3d(self, has_arrow=False):
613        """ Generate 3D plot in real space with or without arrows """
614        self.write_new_values_from_gui()
615        graph_title = " Graph {}: {} 3D SLD Profile".format(self.graph_num,
616                                                            self.file_name)
617        if has_arrow:
618            graph_title += ' - Magnetic Vector as Arrow'
619
620        plot3D = Plotter3D(self, graph_title)
621        plot3D.plot(self.sld_data, has_arrow=has_arrow)
622        plot3D.show()
623        self.graph_num += 1
624
625    def plot_1_2d(self, d):
626        """ Generate 1D or 2D plot, called in Compute"""
627        if self.is_avg or self.is_avg is None:
628            data = Data1D(x=self.data.x, y=self.data_to_plot)
629            data.title = "GenSAS {}  #{} 1D".format(self.file_name,
630                                                    int(self.graph_num))
631            data.xaxis('\\rm{Q_{x}}', '\AA^{-1}')
632            data.yaxis('\\rm{Intensity}', 'cm^{-1}')
633            plot1D = Plotter(self)
634            plot1D.plot(data)
635            plot1D.show()
636            self.graph_num += 1
637            # TODO
638            print 'TRANSFER OF DATA TO MAIN PANEL TO BE IMPLEMENTED'
639            return plot1D
640        else:
641            numpy.nan_to_num(self.data_to_plot)
642            data = Data2D(image=self.data_to_plot,
643                          qx_data=self.data.qx_data,
644                          qy_data=self.data.qy_data,
645                          q_data=self.data.q_data,
646                          xmin=self.data.xmin, xmax=self.data.ymax,
647                          ymin=self.data.ymin, ymax=self.data.ymax,
648                          err_image=self.data.err_data)
649            data.title = "GenSAS {}  #{} 2D".format(self.file_name,
650                                                    int(self.graph_num))
651            plot2D = Plotter2D(self)
652            plot2D.plot(data)
653            plot2D.show()
654            self.graph_num += 1
655            # TODO
656            print 'TRANSFER OF DATA TO MAIN PANEL TO BE IMPLEMENTED'
657            return plot2D
658
659
660class Plotter3DWidget(PlotterBase):
661    """
662    3D Plot widget for use with a QDialog
663    """
664    def __init__(self, parent=None, manager=None):
665        super(Plotter3DWidget, self).__init__(parent,  manager=manager)
666
667    @property
668    def data(self):
669        return self._data
670
671    @data.setter
672    def data(self, data=None):
673        """ data setter """
674        self._data = data
675
676    def plot(self, data=None, has_arrow=False):
677        """
678        Plot 3D self._data
679        """
680        self.data = data
681        assert(self._data)
682        # Prepare and show the plot
683        self.showPlot(data=self.data, has_arrow=has_arrow)
684
685    def showPlot(self, data, has_arrow=False):
686        """
687        Render and show the current data
688        """
689        # If we don't have any data, skip.
690        if data is None:
691            return
692        color_dic = {'H': 'blue', 'D': 'purple', 'N': 'orange',
693                     'O': 'red', 'C': 'green', 'P': 'cyan', 'Other': 'k'}
694        marker = ','
695        m_size = 2
696
697        pos_x = data.pos_x
698        pos_y = data.pos_y
699        pos_z = data.pos_z
700        sld_mx = data.sld_mx
701        sld_my = data.sld_my
702        sld_mz = data.sld_mz
703        pix_symbol = data.pix_symbol
704        sld_tot = numpy.fabs(sld_mx) + numpy.fabs(sld_my) + \
705                  numpy.fabs(sld_mz) + numpy.fabs(data.sld_n)
706        is_nonzero = sld_tot > 0.0
707        is_zero = sld_tot == 0.0
708
709        if data.pix_type == 'atom':
710            marker = 'o'
711            m_size = 3.5
712
713        self.figure.clear()
714        self.figure.subplots_adjust(left=0.1, right=.8, bottom=.1)
715        ax = Axes3D(self.figure)
716        ax.set_xlabel('x ($\A{}$)'.format(data.pos_unit))
717        ax.set_ylabel('z ($\A{}$)'.format(data.pos_unit))
718        ax.set_zlabel('y ($\A{}$)'.format(data.pos_unit))
719
720        # I. Plot null points
721        if is_zero.any():
722            im = ax.plot(pos_x[is_zero], pos_z[is_zero], pos_y[is_zero],
723                           marker, c="y", alpha=0.5, markeredgecolor='y',
724                           markersize=m_size)
725            pos_x = pos_x[is_nonzero]
726            pos_y = pos_y[is_nonzero]
727            pos_z = pos_z[is_nonzero]
728            sld_mx = sld_mx[is_nonzero]
729            sld_my = sld_my[is_nonzero]
730            sld_mz = sld_mz[is_nonzero]
731            pix_symbol = data.pix_symbol[is_nonzero]
732        # II. Plot selective points in color
733        other_color = numpy.ones(len(pix_symbol), dtype='bool')
734        for key in color_dic.keys():
735            chosen_color = pix_symbol == key
736            if numpy.any(chosen_color):
737                other_color = other_color & (chosen_color!=True)
738                color = color_dic[key]
739                im = ax.plot(pos_x[chosen_color], pos_z[chosen_color],
740                         pos_y[chosen_color], marker, c=color, alpha=0.5,
741                         markeredgecolor=color, markersize=m_size, label=key)
742        # III. Plot All others
743        if numpy.any(other_color):
744            a_name = ''
745            if data.pix_type == 'atom':
746                # Get atom names not in the list
747                a_names = [symb for symb in pix_symbol \
748                           if symb not in color_dic.keys()]
749                a_name = a_names[0]
750                for name in a_names:
751                    new_name = ", " + name
752                    if a_name.count(name) == 0:
753                        a_name += new_name
754            # plot in black
755            im = ax.plot(pos_x[other_color], pos_z[other_color],
756                         pos_y[other_color], marker, c="k", alpha=0.5,
757                         markeredgecolor="k", markersize=m_size, label=a_name)
758        if data.pix_type == 'atom':
759            ax.legend(loc='upper left', prop={'size': 10})
760        # IV. Draws atomic bond with grey lines if any
761        if data.has_conect:
762            for ind in range(len(data.line_x)):
763                im = ax.plot(data.line_x[ind], data.line_z[ind],
764                             data.line_y[ind], '-', lw=0.6, c="grey",
765                             alpha=0.3)
766        # V. Draws magnetic vectors
767        if has_arrow and len(pos_x) > 0:
768            def _draw_arrow(input=None, update=None):
769                """
770                draw magnetic vectors w/arrow
771                """
772                max_mx = max(numpy.fabs(sld_mx))
773                max_my = max(numpy.fabs(sld_my))
774                max_mz = max(numpy.fabs(sld_mz))
775                max_m = max(max_mx, max_my, max_mz)
776                try:
777                    max_step = max(data.xstepsize, data.ystepsize, data.zstepsize)
778                except:
779                    max_step = 0
780                if max_step <= 0:
781                    max_step = 5
782                try:
783                    if max_m != 0:
784                        unit_x2 = sld_mx / max_m
785                        unit_y2 = sld_my / max_m
786                        unit_z2 = sld_mz / max_m
787                        # 0.8 is for avoiding the color becomes white=(1,1,1))
788                        color_x = numpy.fabs(unit_x2 * 0.8)
789                        color_y = numpy.fabs(unit_y2 * 0.8)
790                        color_z = numpy.fabs(unit_z2 * 0.8)
791                        x2 = pos_x + unit_x2 * max_step
792                        y2 = pos_y + unit_y2 * max_step
793                        z2 = pos_z + unit_z2 * max_step
794                        x_arrow = numpy.column_stack((pos_x, x2))
795                        y_arrow = numpy.column_stack((pos_y, y2))
796                        z_arrow = numpy.column_stack((pos_z, z2))
797                        colors = numpy.column_stack((color_x, color_y, color_z))
798                        arrows = Arrow3D(self.figure, x_arrow, z_arrow, y_arrow,
799                                        colors, mutation_scale=10, lw=1,
800                                        arrowstyle="->", alpha=0.5)
801                        ax.add_artist(arrows)
802                except:
803                    pass
804                log_msg = "Arrow Drawing completed.\n"
805                logging.info(log_msg)
806            log_msg = "Arrow Drawing is in progress..."
807            logging.info(log_msg)
808
809            # Defer the drawing of arrows to another thread
810            d = threads.deferToThread(_draw_arrow, ax)
811
812        self.figure.canvas.resizing = False
813        self.figure.canvas.draw()
814
815
816class Plotter3D(QtGui.QDialog, Plotter3DWidget):
817    def __init__(self, parent=None, graph_title=''):
818        self.graph_title = graph_title
819        QtGui.QDialog.__init__(self)
820        Plotter3DWidget.__init__(self, manager=parent)
821        self.setWindowTitle(self.graph_title)
822
Note: See TracBrowser for help on using the repository browser.