source: sasview/src/sas/qtgui/Calculators/GenericScatteringCalculator.py @ 28a09b0

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalc
Last change on this file since 28a09b0 was 28a09b0, checked in by Piotr Rozyczko <rozyczko@…>, 7 years ago

Merged Celine's implementation of the generic scattering calculator

  • Property mode set to 100755
File size: 33.5 KB
Line 
1import sys
2import os
3import numpy
4import logging
5import time
6
7from PyQt4 import QtGui
8from PyQt4 import QtCore
9from twisted.internet import threads
10
11import sas.qtgui.Utilities.GuiUtils as GuiUtils
12
13from sas.sasgui.perspectives.calculator.load_thread import GenReader
14from sas.sascalc.dataloader.data_info import Detector
15from sas.sascalc.dataloader.data_info import Source
16from sas.sascalc.calculator import sas_gen
17
18from mpl_toolkits.mplot3d import Axes3D
19from sas.qtgui.Plotting.PlotterBase import PlotterBase
20from sas.sasgui.guiframe.dataFitting import Data2D
21from sas.qtgui.Plotting.Plotter2D import Plotter2D
22from sas.sasgui.guiframe.dataFitting import Data1D
23from sas.qtgui.Plotting.Plotter import Plotter
24from sas.sasgui.plottools.arrow3d import Arrow3D
25
26# Local UI
27from UI.GenericScatteringCalculator import Ui_GenericScatteringCalculator
28
29_Q1D_MIN = 0.001
30
31
32class GenericScatteringCalculator(QtGui.QDialog, Ui_GenericScatteringCalculator):
33
34    trigger_plot_3d = QtCore.pyqtSignal()
35
36    def __init__(self, parent=None):
37        super(GenericScatteringCalculator, self).__init__()
38        self.setupUi(self)
39        self.manager = parent
40        self.communicator = self.manager.communicator()
41        self.model = sas_gen.GenSAS()
42        self.omf_reader = sas_gen.OMFReader()
43        self.sld_reader = sas_gen.SLDReader()
44        self.pdb_reader = sas_gen.PDBReader()
45        self.reader = None
46        self.sld_data = None
47
48        self.parameters = []
49        self.data = None
50        self.datafile = None
51        self.file_name = ''
52        self.ext = None
53        self.default_shape = str(self.cbShape.currentText())
54        self.is_avg = False
55        self.data_to_plot = None
56
57        # combox box
58        self.cbOptionsCalc.setVisible(False)
59
60        # push buttons
61        self.cmdClose.clicked.connect(self.accept)
62        self.cmdHelp.clicked.connect(self.onHelp)
63
64        self.cmdLoad.clicked.connect(self.loadFile)
65        self.cmdCompute.clicked.connect(self.onCompute)
66        self.cmdReset.clicked.connect(self.onReset)
67        self.cmdSave.clicked.connect(self.onSaveFile)
68
69        self.cmdDraw.clicked.connect(lambda: self.plot3d(has_arrow=True))
70        self.cmdDrawpoints.clicked.connect(lambda: self.plot3d(has_arrow=False))
71
72        # validators
73        # scale, volume and background must be positive
74        validat_regex_pos = QtCore.QRegExp('^[+]?([.]\d+|\d+([.]\d+)?)$')
75        self.txtScale.setValidator(QtGui.QRegExpValidator(validat_regex_pos,
76                                                          self.txtScale))
77        self.txtBackground.setValidator(QtGui.QRegExpValidator(
78            validat_regex_pos, self.txtBackground))
79        self.txtTotalVolume.setValidator(QtGui.QRegExpValidator(
80            validat_regex_pos, self.txtTotalVolume))
81
82        # fraction of spin up between 0 and 1
83        validat_regexbetween0_1 = QtCore.QRegExp('^(0(\.\d*)*|1(\.0+)?)$')
84        self.txtUpFracIn.setValidator(
85            QtGui.QRegExpValidator(validat_regexbetween0_1, self.txtUpFracIn))
86        self.txtUpFracOut.setValidator(
87            QtGui.QRegExpValidator(validat_regexbetween0_1, self.txtUpFracOut))
88
89        # 0 < Qmax <= 1000
90        validat_regex_q = QtCore.QRegExp('^1000$|^[+]?(\d{1,3}([.]\d+)?)$')
91        self.txtQxMax.setValidator(QtGui.QRegExpValidator(validat_regex_q,
92                                                          self.txtQxMax))
93        self.txtQxMax.textChanged.connect(self.check_value)
94
95        # 2 <= Qbin <= 1000
96        self.txtNoQBins.setValidator(QtGui.QRegExpValidator(validat_regex_q,
97                                                            self.txtNoQBins))
98        self.txtNoQBins.textChanged.connect(self.check_value)
99
100        # plots - 3D in real space
101        self.trigger_plot_3d.connect(lambda: self.plot3d(has_arrow=False))
102
103        self.graph_num = 1  # index for name of graph
104
105        # TODO the option Ellipsoid has not been implemented
106        self.cbShape.currentIndexChanged.connect(self.selectedshapechange)
107
108    def selectedshapechange(self):
109        """
110        TODO Temporary solution to display information about option 'Ellipsoid'
111        """
112        print "The option Ellipsoid has not been implemented yet."
113        self.communicator.statusBarUpdateSignal.emit(
114            "The option Ellipsoid has not been implemented yet.")
115
116    def loadFile(self):
117        """
118        Open menu to choose the datafile to load
119        Only extensions .SLD, .PDB, .OMF, .sld, .pdb, .omf
120        """
121        try:
122            self.datafile = QtGui.QFileDialog.getOpenFileName(
123                self, "Choose a file", "", "All Gen files (*.OMF *.omf) ;;"
124                                          "SLD files (*.SLD *.sld);;PDB files (*.pdb *.PDB);; "
125                                          "OMF files (*.OMF *.omf);; "
126                                          "All files (*.*)")
127            if self.datafile:
128                self.default_shape = str(self.cbShape.currentText())
129                self.file_name = os.path.basename(str(self.datafile))
130                self.ext = os.path.splitext(str(self.datafile))[1]
131                if self.ext in self.omf_reader.ext:
132                    loader = self.omf_reader
133                elif self.ext in self.sld_reader.ext:
134                    loader = self.sld_reader
135                elif self.ext in self.pdb_reader.ext:
136                    loader = self.pdb_reader
137                else:
138                    loader = None
139                # disable some entries depending on type of loaded file
140                # (according to documentation)
141                if self.ext.lower() in ['.sld', '.omf', '.pdb']:
142                    self.txtUpFracIn.setEnabled(False)
143                    self.txtUpFracOut.setEnabled(False)
144                    self.txtUpTheta.setEnabled(False)
145
146                if self.reader is not None and self.reader.isrunning():
147                    self.reader.stop()
148                self.cmdLoad.setEnabled(False)
149                self.cmdLoad.setText('Loading...')
150                self.communicator.statusBarUpdateSignal.emit(
151                    "Loading File {}".format(os.path.basename(
152                        str(self.datafile))))
153                self.reader = GenReader(path=str(self.datafile), loader=loader,
154                                        completefn=self.complete_loading,
155                                        updatefn=self.load_update)
156                self.reader.queue()
157        except (RuntimeError, IOError):
158            log_msg = "Generic SAS Calculator: %s" % sys.exc_value
159            logging.info(log_msg)
160            raise
161        return
162
163    def load_update(self):
164        """ Legacy function used in GenRead """
165        if self.reader.isrunning():
166            status_type = "progress"
167        else:
168            status_type = "stop"
169        logging.info(status_type)
170
171    def complete_loading(self, data=None):
172        """ Function used in GenRead"""
173        self.cbShape.setEnabled(False)
174        try:
175            is_pdbdata = False
176            self.txtData.setText(os.path.basename(str(self.datafile)))
177            self.is_avg = False
178            if self.ext in self.omf_reader.ext:
179                gen = sas_gen.OMF2SLD()
180                gen.set_data(data)
181                self.sld_data = gen.get_magsld()
182                self.check_units()
183            elif self.ext in self.sld_reader.ext:
184                self.sld_data = data
185            elif self.ext in self.pdb_reader.ext:
186                self.sld_data = data
187                is_pdbdata = True
188            # Display combobox of orientation only for pdb data
189            self.cbOptionsCalc.setVisible(is_pdbdata)
190            self.update_gui()
191        except IOError:
192            log_msg = "Loading Error: " \
193                      "This file format is not supported for GenSAS."
194            logging.info(log_msg)
195            raise
196        except ValueError:
197            log_msg = "Could not find any data"
198            logging.info(log_msg)
199            raise
200        logging.info("Load Complete")
201        self.cmdLoad.setEnabled(True)
202        self.cmdLoad.setText('Load')
203        self.trigger_plot_3d.emit()
204
205    def check_units(self):
206        """
207        Check if the units from the OMF file correspond to the default ones
208        displayed on the interface.
209        If not, modify the GUI with the correct unit
210        """
211        #  TODO: adopt the convention of font and symbol for the updated values
212        if sas_gen.OMFData().valueunit != 'A^(-2)':
213            value_unit = sas_gen.OMFData().valueunit
214            self.lbl_unitMx.setText(value_unit)
215            self.lbl_unitMy.setText(value_unit)
216            self.lbl_unitMz.setText(value_unit)
217            self.lbl_unitNucl.setText(value_unit)
218        if sas_gen.OMFData().meshunit != 'A':
219            mesh_unit = sas_gen.OMFData().meshunit
220            self.lbl_unitx.setText(mesh_unit)
221            self.lbl_unity.setText(mesh_unit)
222            self.lbl_unitz.setText(mesh_unit)
223            self.lbl_unitVolume.setText(mesh_unit+"^3")
224
225    def check_value(self):
226        """Check range of text edits for QMax and Number of Qbins """
227        text_edit = self.sender()
228        text_edit.setStyleSheet(
229            QtCore.QString.fromUtf8('background-color: rgb(255, 255, 255);'))
230        if text_edit.text():
231            value = float(str(text_edit.text()))
232            if text_edit == self.txtQxMax:
233                if value <= 0 or value > 1000:
234                    text_edit.setStyleSheet(QtCore.QString.fromUtf8(
235                        'background-color: rgb(255, 182, 193);'))
236                else:
237                    text_edit.setStyleSheet(QtCore.QString.fromUtf8(
238                        'background-color: rgb(255, 255, 255);'))
239            elif text_edit == self.txtNoQBins:
240                if value < 2 or value > 1000:
241                    self.txtNoQBins.setStyleSheet(QtCore.QString.fromUtf8(
242                        'background-color: rgb(255, 182, 193);'))
243                else:
244                    self.txtNoQBins.setStyleSheet(QtCore.QString.fromUtf8(
245                        'background-color: rgb(255, 255, 255);'))
246
247    def update_gui(self):
248        """ Update the interface with values from loaded data """
249        self.model.set_is_avg(self.is_avg)
250        self.model.set_sld_data(self.sld_data)
251        self.model.params['total_volume'] = len(self.sld_data.sld_n)*self.sld_data.vol_pix[0]
252
253        # add condition for activation of save button
254        self.cmdSave.setEnabled(True)
255
256        # activation of 3D plots' buttons (with and without arrows)
257        self.cmdDraw.setEnabled(self.sld_data is not None)
258        self.cmdDrawpoints.setEnabled(self.sld_data is not None)
259
260        self.txtScale.setText(str(self.model.params['scale']))
261        self.txtBackground.setText(str(self.model.params['background']))
262        self.txtSolventSLD.setText(str(self.model.params['solvent_SLD']))
263
264        # Volume to write to interface: npts x volume of first pixel
265        self.txtTotalVolume.setText(str(len(self.sld_data.sld_n)*self.sld_data.vol_pix[0]))
266
267        self.txtUpFracIn.setText(str(self.model.params['Up_frac_in']))
268        self.txtUpFracOut.setText(str(self.model.params['Up_frac_out']))
269        self.txtUpTheta.setText(str(self.model.params['Up_theta']))
270
271        self.txtNoPixels.setText(str(len(self.sld_data.sld_n)))
272        self.txtNoPixels.setEnabled(False)
273
274        list_parameters = ['sld_mx', 'sld_my', 'sld_mz', 'sld_n', 'xnodes',
275                           'ynodes', 'znodes', 'xstepsize', 'ystepsize',
276                           'zstepsize']
277        list_gui_button = [self.txtMx, self.txtMy, self.txtMz, self.txtNucl,
278                           self.txtXnodes, self.txtYnodes, self.txtZnodes,
279                           self.txtXstepsize, self.txtYstepsize,
280                           self.txtZstepsize]
281
282        # Fill right hand side of GUI
283        for indx, item in enumerate(list_parameters):
284            if getattr(self.sld_data, item) is None:
285                list_gui_button[indx].setText('NaN')
286            else:
287                value = getattr(self.sld_data, item)
288                if isinstance(value, numpy.ndarray):
289                    item_for_gui = str(GuiUtils.formatNumber(numpy.average(value), True))
290                else:
291                    item_for_gui = str(GuiUtils.formatNumber(value, True))
292                list_gui_button[indx].setText(item_for_gui)
293
294        # Enable / disable editing of right hand side of GUI
295        for indx, item in enumerate(list_parameters):
296            if indx < 4:
297                # this condition only applies to Mx,y,z and Nucl
298                value = getattr(self.sld_data, item)
299                enable = self.sld_data.pix_type == 'pixel' \
300                         and numpy.min(value) == numpy.max(value)
301            else:
302                enable = not self.sld_data.is_data
303            list_gui_button[indx].setEnabled(enable)
304
305    def write_new_values_from_gui(self):
306        """
307        update parameters using modified inputs from GUI
308        used before computing
309        """
310        if self.txtScale.isModified():
311            self.model.params['scale'] = float(self.txtScale.text())
312
313        if self.txtBackground.isModified():
314            self.model.params['background'] = float(self.txtBackground.text())
315
316        if self.txtSolventSLD.isModified():
317            self.model.params['solvent_SLD'] = float(self.txtSolventSLD.text())
318
319        # Different condition for total volume to get correct volume after
320        # applying set_sld_data in compute
321        if self.txtTotalVolume.isModified() \
322                or self.model.params['total_volume'] != float(self.txtTotalVolume.text()):
323            self.model.params['total_volume'] = float(self.txtTotalVolume.text())
324
325        if self.txtUpFracIn.isModified():
326            self.model.params['Up_frac_in'] = float(self.txtUpFracIn.text())
327
328        if self.txtUpFracOut.isModified():
329            self.model.params['Up_frac_out'] = float(self.txtUpFracOut.text())
330
331        if self.txtUpTheta.isModified():
332            self.model.params['Up_theta'] = float(self.txtUpTheta.text())
333
334        if self.txtMx.isModified():
335            self.sld_data.sld_mx = float(self.txtMx.text())*\
336                                   numpy.ones(len(self.sld_data.sld_mx))
337
338        if self.txtMy.isModified():
339            self.sld_data.sld_my = float(self.txtMy.text())*\
340                                   numpy.ones(len(self.sld_data.sld_my))
341
342        if self.txtMz.isModified():
343            self.sld_data.sld_mz = float(self.txtMz.text())*\
344                                   numpy.ones(len(self.sld_data.sld_mz))
345
346        if self.txtNucl.isModified():
347            self.sld_data.sld_n = float(self.txtNucl.text())*\
348                                  numpy.ones(len(self.sld_data.sld_n))
349
350        if self.txtXnodes.isModified():
351            self.sld_data.xnodes = int(self.txtXnodes.text())
352
353        if self.txtYnodes.isModified():
354            self.sld_data.ynodes = int(self.txtYnodes.text())
355
356        if self.txtZnodes.isModified():
357            self.sld_data.znodes = int(self.txtZnodes.text())
358
359        if self.txtXstepsize.isModified():
360            self.sld_data.xstepsize = float(self.txtXstepsize.text())
361
362        if self.txtYstepsize.isModified():
363            self.sld_data.ystepsize = float(self.txtYstepsize.text())
364
365        if self.txtZstepsize.isModified():
366            self.sld_data.zstepsize = float(self.txtZstepsize.text())
367
368        if self.cbOptionsCalc.isVisible():
369            self.is_avg = (self.cbOptionsCalc.currentIndex() == 1)
370
371    def onHelp(self):
372        """
373        Bring up the Generic Scattering calculator Documentation whenever
374        the HELP button is clicked.
375        Calls Documentation Window with the path of the location within the
376        documentation tree (after /doc/ ....".
377        """
378        try:
379            location = self.manager.HELP_DIRECTORY_LOCATION + \
380                       "/user/sasgui/perspectives/calculator/sas_calculator_help.html"
381            self.manager._helpView.load(QtCore.QUrl(location))
382            self.manager._helpView.show()
383        except AttributeError:
384            # No manager defined - testing and standalone runs
385            pass
386
387    def onReset(self):
388        """ Reset the inputs of textEdit to default values """
389        try:
390            # reset values in textedits
391            self.txtUpFracIn.setText("1.0")
392            self.txtUpFracOut.setText("1.0")
393            self.txtUpTheta.setText("0.0")
394            self.txtBackground.setText("0.0")
395            self.txtScale.setText("1.0")
396            self.txtSolventSLD.setText("0.0")
397            self.txtTotalVolume.setText("216000.0")
398            self.txtNoQBins.setText("50")
399            self.txtQxMax.setText("0.3")
400            self.txtNoPixels.setText("1000")
401            self.txtMx.setText("0")
402            self.txtMy.setText("0")
403            self.txtMz.setText("0")
404            self.txtNucl.setText("6.97e-06")
405            self.txtXnodes.setText("10")
406            self.txtYnodes.setText("10")
407            self.txtZnodes.setText("10")
408            self.txtXstepsize.setText("6")
409            self.txtYstepsize.setText("6")
410            self.txtZstepsize.setText("6")
411            # reset Load button and textedit
412            self.txtData.setText('Default SLD Profile')
413            self.cmdLoad.setEnabled(True)
414            self.cmdLoad.setText('Load')
415            # reset option for calculation
416            self.cbOptionsCalc.setCurrentIndex(0)
417            self.cbOptionsCalc.setVisible(False)
418            # reset shape button
419            self.cbShape.setCurrentIndex(0)
420            self.cbShape.setEnabled(True)
421            # reset compute button
422            self.cmdCompute.setText('Compute')
423            self.cmdCompute.setEnabled(True)
424            # TODO reload default data set
425            self._create_default_sld_data()
426
427        finally:
428            pass
429
430    def _create_default_2d_data(self):
431        """
432        Copied from previous version
433        Create 2D data by default
434        :warning: This data is never plotted.
435        """
436        self.qmax_x = float(self.txtQxMax.text())
437        self.npts_x = int(self.txtNoQBins.text())
438        self.data = Data2D()
439        self.data.is_data = False
440        # # Default values
441        self.data.detector.append(Detector())
442        index = len(self.data.detector) - 1
443        self.data.detector[index].distance = 8000  # mm
444        self.data.source.wavelength = 6  # A
445        self.data.detector[index].pixel_size.x = 5  # mm
446        self.data.detector[index].pixel_size.y = 5  # mm
447        self.data.detector[index].beam_center.x = self.qmax_x
448        self.data.detector[index].beam_center.y = self.qmax_x
449        xmax = self.qmax_x
450        xmin = -xmax
451        ymax = self.qmax_x
452        ymin = -ymax
453        qstep = self.npts_x
454
455        x = numpy.linspace(start=xmin, stop=xmax, num=qstep, endpoint=True)
456        y = numpy.linspace(start=ymin, stop=ymax, num=qstep, endpoint=True)
457        # use data info instead
458        new_x = numpy.tile(x, (len(y), 1))
459        new_y = numpy.tile(y, (len(x), 1))
460        new_y = new_y.swapaxes(0, 1)
461        # all data require now in 1d array
462        qx_data = new_x.flatten()
463        qy_data = new_y.flatten()
464        q_data = numpy.sqrt(qx_data * qx_data + qy_data * qy_data)
465        # set all True (standing for unmasked) as default
466        mask = numpy.ones(len(qx_data), dtype=bool)
467        self.data.source = Source()
468        self.data.data = numpy.ones(len(mask))
469        self.data.err_data = numpy.ones(len(mask))
470        self.data.qx_data = qx_data
471        self.data.qy_data = qy_data
472        self.data.q_data = q_data
473        self.data.mask = mask
474        # store x and y bin centers in q space
475        self.data.x_bins = x
476        self.data.y_bins = y
477        # max and min taking account of the bin sizes
478        self.data.xmin = xmin
479        self.data.xmax = xmax
480        self.data.ymin = ymin
481        self.data.ymax = ymax
482
483    def _create_default_sld_data(self):
484        """
485        Copied from previous version
486        Making default sld-data
487        """
488        sld_n_default = 6.97e-06  # what is this number??
489        omfdata = sas_gen.OMFData()
490        omf2sld = sas_gen.OMF2SLD()
491        omf2sld.set_data(omfdata, self.default_shape)
492        self.sld_data = omf2sld.output
493        self.sld_data.is_data = False
494        self.sld_data.filename = "Default SLD Profile"
495        self.sld_data.set_sldn(sld_n_default)
496
497    def _create_default_1d_data(self):
498        """
499        Copied from previous version
500        Create 1D data by default
501        :warning: This data is never plotted.
502                    residuals.x = data_copy.x[index]
503            residuals.dy = numpy.ones(len(residuals.y))
504            residuals.dx = None
505            residuals.dxl = None
506            residuals.dxw = None
507        """
508        self.qmax_x = float(self.txtQxMax.text())
509        self.npts_x = int(self.txtNoQBins.text())
510        #  Default values
511        xmax = self.qmax_x
512        xmin = self.qmax_x * _Q1D_MIN
513        qstep = self.npts_x
514        x = numpy.linspace(start=xmin, stop=xmax, num=qstep, endpoint=True)
515        # store x and y bin centers in q space
516        y = numpy.ones(len(x))
517        dy = numpy.zeros(len(x))
518        dx = numpy.zeros(len(x))
519        self.data = Data1D(x=x, y=y)
520        self.data.dx = dx
521        self.data.dy = dy
522
523    def onCompute(self):
524        """
525        Copied from previous version
526        Execute the computation of I(qx, qy)
527        """
528        # Set default data when nothing loaded yet
529        if self.sld_data is None:
530            self._create_default_sld_data()
531        try:
532            self.model.set_sld_data(self.sld_data)
533            self.write_new_values_from_gui()
534            if self.is_avg or self.is_avg is None:
535                self._create_default_1d_data()
536                i_out = numpy.zeros(len(self.data.y))
537                inputs = [self.data.x, [], i_out]
538            else:
539                self._create_default_2d_data()
540                i_out = numpy.zeros(len(self.data.data))
541                inputs = [self.data.qx_data, self.data.qy_data, i_out]
542            logging.info("Computation is in progress...")
543            self.cmdCompute.setText('Wait...')
544            self.cmdCompute.setEnabled(False)
545            d = threads.deferToThread(self.complete, inputs, self._update)
546            # Add deferred callback for call return
547            d.addCallback(self.plot_1_2d)
548        except:
549            log_msg = "{}. stop".format(sys.exc_value)
550            logging.info(log_msg)
551        return
552
553    def _update(self, value):
554        """
555        Copied from previous version
556        """
557        pass
558
559    def complete(self, input, update=None):
560        """
561        Gen compute complete function
562        :Param input: input list [qx_data, qy_data, i_out]
563        """
564        out = numpy.empty(0)
565        for ind in range(len(input[0])):
566            if self.is_avg:
567                if ind % 1 == 0 and update is not None:
568                    # update()
569                    percentage = int(100.0 * float(ind) / len(input[0]))
570                    update(percentage)
571                    time.sleep(0.001)  # 0.1
572                inputi = [input[0][ind:ind + 1], [], input[2][ind:ind + 1]]
573                outi = self.model.run(inputi)
574                out = numpy.append(out, outi)
575            else:
576                if ind % 50 == 0 and update is not None:
577                    percentage = int(100.0 * float(ind) / len(input[0]))
578                    update(percentage)
579                    time.sleep(0.001)
580                inputi = [input[0][ind:ind + 1], input[1][ind:ind + 1],
581                          input[2][ind:ind + 1]]
582                outi = self.model.runXY(inputi)
583                out = numpy.append(out, outi)
584        self.data_to_plot = out
585        logging.info('Gen computation completed.')
586        self.cmdCompute.setText('Compute')
587        self.cmdCompute.setEnabled(True)
588        return
589
590    def onSaveFile(self):
591        """Save data as .sld file"""
592        path = os.path.dirname(str(self.datafile))
593        default_name = os.path.join(path, 'sld_file')
594        kwargs = {
595            'parent': self,
596            'directory': default_name,
597            'filter': 'SLD file (*.sld)',
598            'options': QtGui.QFileDialog.DontUseNativeDialog}
599        # Query user for filename.
600        filename = str(QtGui.QFileDialog.getSaveFileName(**kwargs))
601        if filename:
602            try:
603                if os.path.splitext(filename)[1].lower() == '.sld':
604                    sas_gen.SLDReader().write(filename, self.sld_data)
605                else:
606                    sas_gen.SLDReader().write('.'.join((filename, 'sld')),
607                                              self.sld_data)
608            except:
609                raise
610
611    def plot3d(self, has_arrow=False):
612        """ Generate 3D plot in real space with or without arrows """
613        self.write_new_values_from_gui()
614        graph_title = " Graph {}: {} 3D SLD Profile".format(self.graph_num,
615                                                            self.file_name)
616        if has_arrow:
617            graph_title += ' - Magnetic Vector as Arrow'
618
619        plot3D = Plotter3D(self, graph_title)
620        plot3D.plot(self.sld_data, has_arrow=has_arrow)
621        plot3D.show()
622        self.graph_num += 1
623
624    def plot_1_2d(self, d):
625        """ Generate 1D or 2D plot, called in Compute"""
626        if self.is_avg or self.is_avg is None:
627            data = Data1D(x=self.data.x, y=self.data_to_plot)
628            data.title = "GenSAS {}  #{} 1D".format(self.file_name,
629                                                    int(self.graph_num))
630            data.xaxis('\\rm{Q_{x}}', '\AA^{-1}')
631            data.yaxis('\\rm{Intensity}', 'cm^{-1}')
632            plot1D = Plotter(self)
633            plot1D.plot(data)
634            plot1D.show()
635            self.graph_num += 1
636            # TODO
637            print 'TRANSFER OF DATA TO MAIN PANEL TO BE IMPLEMENTED'
638            return plot1D
639        else:
640            numpy.nan_to_num(self.data_to_plot)
641            data = Data2D(image=self.data_to_plot,
642                          qx_data=self.data.qx_data,
643                          qy_data=self.data.qy_data,
644                          q_data=self.data.q_data,
645                          xmin=self.data.xmin, xmax=self.data.ymax,
646                          ymin=self.data.ymin, ymax=self.data.ymax,
647                          err_image=self.data.err_data)
648            data.title = "GenSAS {}  #{} 2D".format(self.file_name,
649                                                    int(self.graph_num))
650            plot2D = Plotter2D(self)
651            plot2D.plot(data)
652            plot2D.show()
653            self.graph_num += 1
654            # TODO
655            print 'TRANSFER OF DATA TO MAIN PANEL TO BE IMPLEMENTED'
656            return plot2D
657
658
659class Plotter3DWidget(PlotterBase):
660    """
661    3D Plot widget for use with a QDialog
662    """
663    def __init__(self, parent=None, manager=None):
664        super(Plotter3DWidget, self).__init__(parent,  manager=manager)
665
666    @property
667    def data(self):
668        return self._data
669
670    @data.setter
671    def data(self, data=None):
672        """ data setter """
673        self._data = data
674
675    def plot(self, data=None, has_arrow=False):
676        """
677        Plot 3D self._data
678        """
679        self.data = data
680        assert(self._data)
681        # Prepare and show the plot
682        self.showPlot(data=self.data, has_arrow=has_arrow)
683
684    def showPlot(self, data, has_arrow=False):
685        """
686        Render and show the current data
687        """
688        # If we don't have any data, skip.
689        if data is None:
690            return
691        color_dic = {'H': 'blue', 'D': 'purple', 'N': 'orange',
692                     'O': 'red', 'C': 'green', 'P': 'cyan', 'Other': 'k'}
693        marker = ','
694        m_size = 2
695
696        pos_x = data.pos_x
697        pos_y = data.pos_y
698        pos_z = data.pos_z
699        sld_mx = data.sld_mx
700        sld_my = data.sld_my
701        sld_mz = data.sld_mz
702        pix_symbol = data.pix_symbol
703        sld_tot = numpy.fabs(sld_mx) + numpy.fabs(sld_my) + \
704                  numpy.fabs(sld_mz) + numpy.fabs(data.sld_n)
705        is_nonzero = sld_tot > 0.0
706        is_zero = sld_tot == 0.0
707
708        if data.pix_type == 'atom':
709            marker = 'o'
710            m_size = 3.5
711
712        self.figure.clear()
713        self.figure.subplots_adjust(left=0.1, right=.8, bottom=.1)
714        ax = Axes3D(self.figure)
715        ax.set_xlabel('x ($\A{}$)'.format(data.pos_unit))
716        ax.set_ylabel('z ($\A{}$)'.format(data.pos_unit))
717        ax.set_zlabel('y ($\A{}$)'.format(data.pos_unit))
718
719        # I. Plot null points
720        if is_zero.any():
721            im = ax.plot(pos_x[is_zero], pos_z[is_zero], pos_y[is_zero],
722                           marker, c="y", alpha=0.5, markeredgecolor='y',
723                           markersize=m_size)
724            pos_x = pos_x[is_nonzero]
725            pos_y = pos_y[is_nonzero]
726            pos_z = pos_z[is_nonzero]
727            sld_mx = sld_mx[is_nonzero]
728            sld_my = sld_my[is_nonzero]
729            sld_mz = sld_mz[is_nonzero]
730            pix_symbol = data.pix_symbol[is_nonzero]
731        # II. Plot selective points in color
732        other_color = numpy.ones(len(pix_symbol), dtype='bool')
733        for key in color_dic.keys():
734            chosen_color = pix_symbol == key
735            if numpy.any(chosen_color):
736                other_color = other_color & (chosen_color!=True)
737                color = color_dic[key]
738                im = ax.plot(pos_x[chosen_color], pos_z[chosen_color],
739                         pos_y[chosen_color], marker, c=color, alpha=0.5,
740                         markeredgecolor=color, markersize=m_size, label=key)
741        # III. Plot All others
742        if numpy.any(other_color):
743            a_name = ''
744            if data.pix_type == 'atom':
745                # Get atom names not in the list
746                a_names = [symb for symb in pix_symbol \
747                           if symb not in color_dic.keys()]
748                a_name = a_names[0]
749                for name in a_names:
750                    new_name = ", " + name
751                    if a_name.count(name) == 0:
752                        a_name += new_name
753            # plot in black
754            im = ax.plot(pos_x[other_color], pos_z[other_color],
755                         pos_y[other_color], marker, c="k", alpha=0.5,
756                         markeredgecolor="k", markersize=m_size, label=a_name)
757        if data.pix_type == 'atom':
758            ax.legend(loc='upper left', prop={'size': 10})
759        # IV. Draws atomic bond with grey lines if any
760        if data.has_conect:
761            for ind in range(len(data.line_x)):
762                im = ax.plot(data.line_x[ind], data.line_z[ind],
763                             data.line_y[ind], '-', lw=0.6, c="grey",
764                             alpha=0.3)
765        # V. Draws magnetic vectors
766        if has_arrow and len(pos_x) > 0:
767            def _draw_arrow(input=None, update=None):
768                """
769                draw magnetic vectors w/arrow
770                """
771                max_mx = max(numpy.fabs(sld_mx))
772                max_my = max(numpy.fabs(sld_my))
773                max_mz = max(numpy.fabs(sld_mz))
774                max_m = max(max_mx, max_my, max_mz)
775                try:
776                    max_step = max(data.xstepsize, data.ystepsize, data.zstepsize)
777                except:
778                    max_step = 0
779                if max_step <= 0:
780                    max_step = 5
781                try:
782                    if max_m != 0:
783                        unit_x2 = sld_mx / max_m
784                        unit_y2 = sld_my / max_m
785                        unit_z2 = sld_mz / max_m
786                        # 0.8 is for avoiding the color becomes white=(1,1,1))
787                        color_x = numpy.fabs(unit_x2 * 0.8)
788                        color_y = numpy.fabs(unit_y2 * 0.8)
789                        color_z = numpy.fabs(unit_z2 * 0.8)
790                        x2 = pos_x + unit_x2 * max_step
791                        y2 = pos_y + unit_y2 * max_step
792                        z2 = pos_z + unit_z2 * max_step
793                        x_arrow = numpy.column_stack((pos_x, x2))
794                        y_arrow = numpy.column_stack((pos_y, y2))
795                        z_arrow = numpy.column_stack((pos_z, z2))
796                        colors = numpy.column_stack((color_x, color_y, color_z))
797                        arrows = Arrow3D(self.figure, x_arrow, z_arrow, y_arrow,
798                                        colors, mutation_scale=10, lw=1,
799                                        arrowstyle="->", alpha=0.5)
800                        ax.add_artist(arrows)
801                except:
802                    pass
803                log_msg = "Arrow Drawing completed.\n"
804                logging.info(log_msg)
805            log_msg = "Arrow Drawing is in progress..."
806            logging.info(log_msg)
807
808            # Defer the drawing of arrows to another thread
809            d = threads.deferToThread(_draw_arrow, ax)
810
811        self.figure.canvas.resizing = False
812        self.figure.canvas.draw()
813
814
815class Plotter3D(QtGui.QDialog, Plotter3DWidget):
816    def __init__(self, parent=None, graph_title=''):
817        self.graph_title = graph_title
818        QtGui.QDialog.__init__(self)
819        Plotter3DWidget.__init__(self, manager=parent)
820        self.setWindowTitle(self.graph_title)
821
Note: See TracBrowser for help on using the repository browser.