source: sasview/src/sas/qtgui/Calculators/GenericScatteringCalculator.py @ 464cd07

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalc
Last change on this file since 464cd07 was 464cd07, checked in by Piotr Rozyczko <rozyczko@…>, 7 years ago

Use singleton QApplication in unit tests to avoid issues on Ubuntu. SASVIEW-485

  • Property mode set to 100644
File size: 33.4 KB
Line 
1import sys
2import os
3import numpy
4import logging
5import time
6
7from PyQt4 import QtGui
8from PyQt4 import QtCore
9from twisted.internet import threads
10from mpl_toolkits.mplot3d import Axes3D
11
12import sas.qtgui.Utilities.GuiUtils as GuiUtils
13
14from sas.qtgui.Utilities.GenericReader import GenReader
15
16from sas.sascalc.dataloader.data_info import Detector
17from sas.sascalc.dataloader.data_info import Source
18from sas.sascalc.calculator import sas_gen
19
20from sas.qtgui.Plotting.Arrow3D import Arrow3D
21from sas.qtgui.Plotting.PlotterBase import PlotterBase
22from sas.qtgui.Plotting.Plotter2D import Plotter2D
23from sas.qtgui.Plotting.Plotter import Plotter
24from sas.qtgui.Plotting.PlotterData import Data1D
25from sas.qtgui.Plotting.PlotterData import Data2D
26
27# Local UI
28from UI.GenericScatteringCalculator import Ui_GenericScatteringCalculator
29
30_Q1D_MIN = 0.001
31
32
33class GenericScatteringCalculator(QtGui.QDialog, Ui_GenericScatteringCalculator):
34
35    trigger_plot_3d = QtCore.pyqtSignal()
36
37    def __init__(self, parent=None):
38        super(GenericScatteringCalculator, self).__init__()
39        self.setupUi(self)
40        self.manager = parent
41        self.communicator = self.manager.communicator()
42        self.model = sas_gen.GenSAS()
43        self.omf_reader = sas_gen.OMFReader()
44        self.sld_reader = sas_gen.SLDReader()
45        self.pdb_reader = sas_gen.PDBReader()
46        self.reader = None
47        self.sld_data = None
48
49        self.parameters = []
50        self.data = None
51        self.datafile = None
52        self.file_name = ''
53        self.ext = None
54        self.default_shape = str(self.cbShape.currentText())
55        self.is_avg = False
56        self.data_to_plot = None
57
58        # combox box
59        self.cbOptionsCalc.setVisible(False)
60
61        # push buttons
62        self.cmdClose.clicked.connect(self.accept)
63        self.cmdHelp.clicked.connect(self.onHelp)
64
65        self.cmdLoad.clicked.connect(self.loadFile)
66        self.cmdCompute.clicked.connect(self.onCompute)
67        self.cmdReset.clicked.connect(self.onReset)
68        self.cmdSave.clicked.connect(self.onSaveFile)
69
70        self.cmdDraw.clicked.connect(lambda: self.plot3d(has_arrow=True))
71        self.cmdDrawpoints.clicked.connect(lambda: self.plot3d(has_arrow=False))
72
73        # validators
74        # scale, volume and background must be positive
75        validat_regex_pos = QtCore.QRegExp('^[+]?([.]\d+|\d+([.]\d+)?)$')
76        self.txtScale.setValidator(QtGui.QRegExpValidator(validat_regex_pos,
77                                                          self.txtScale))
78        self.txtBackground.setValidator(QtGui.QRegExpValidator(
79            validat_regex_pos, self.txtBackground))
80        self.txtTotalVolume.setValidator(QtGui.QRegExpValidator(
81            validat_regex_pos, self.txtTotalVolume))
82
83        # fraction of spin up between 0 and 1
84        validat_regexbetween0_1 = QtCore.QRegExp('^(0(\.\d*)*|1(\.0+)?)$')
85        self.txtUpFracIn.setValidator(
86            QtGui.QRegExpValidator(validat_regexbetween0_1, self.txtUpFracIn))
87        self.txtUpFracOut.setValidator(
88            QtGui.QRegExpValidator(validat_regexbetween0_1, self.txtUpFracOut))
89
90        # 0 < Qmax <= 1000
91        validat_regex_q = QtCore.QRegExp('^1000$|^[+]?(\d{1,3}([.]\d+)?)$')
92        self.txtQxMax.setValidator(QtGui.QRegExpValidator(validat_regex_q,
93                                                          self.txtQxMax))
94        self.txtQxMax.textChanged.connect(self.check_value)
95
96        # 2 <= Qbin <= 1000
97        self.txtNoQBins.setValidator(QtGui.QRegExpValidator(validat_regex_q,
98                                                            self.txtNoQBins))
99        self.txtNoQBins.textChanged.connect(self.check_value)
100
101        # plots - 3D in real space
102        plot3d = self.plot3d(has_arrow=False)
103        self.trigger_plot_3d.connect(plot3d)
104        #self.trigger_plot_3d.connect(lambda: self.plot3d(has_arrow=False))
105
106        self.graph_num = 1  # index for name of graph
107
108        # TODO the option Ellipsoid has not been implemented
109        self.cbShape.currentIndexChanged.connect(self.selectedshapechange)
110
111        # New font to display angstrom symbol
112        new_font = 'font-family: -apple-system, "Helvetica Neue", "Ubuntu";'
113        self.lblUnitSolventSLD.setStyleSheet(new_font)
114        self.lblUnitVolume.setStyleSheet(new_font)
115        self.lbl5.setStyleSheet(new_font)
116        self.lblUnitMx.setStyleSheet(new_font)
117        self.lblUnitMy.setStyleSheet(new_font)
118        self.lblUnitMz.setStyleSheet(new_font)
119        self.lblUnitNucl.setStyleSheet(new_font)
120        self.lblUnitx.setStyleSheet(new_font)
121        self.lblUnity.setStyleSheet(new_font)
122        self.lblUnitz.setStyleSheet(new_font)
123
124    def selectedshapechange(self):
125        """
126        TODO Temporary solution to display information about option 'Ellipsoid'
127        """
128        print "The option Ellipsoid has not been implemented yet."
129        self.communicator.statusBarUpdateSignal.emit(
130            "The option Ellipsoid has not been implemented yet.")
131
132    def loadFile(self):
133        """
134        Open menu to choose the datafile to load
135        Only extensions .SLD, .PDB, .OMF, .sld, .pdb, .omf
136        """
137        try:
138            self.datafile = QtGui.QFileDialog.getOpenFileName(
139                self, "Choose a file", "", "All Gen files (*.OMF *.omf) ;;"
140                                          "SLD files (*.SLD *.sld);;PDB files (*.pdb *.PDB);; "
141                                          "OMF files (*.OMF *.omf);; "
142                                          "All files (*.*)")
143            if self.datafile:
144                self.default_shape = str(self.cbShape.currentText())
145                self.file_name = os.path.basename(str(self.datafile))
146                self.ext = os.path.splitext(str(self.datafile))[1]
147                if self.ext in self.omf_reader.ext:
148                    loader = self.omf_reader
149                elif self.ext in self.sld_reader.ext:
150                    loader = self.sld_reader
151                elif self.ext in self.pdb_reader.ext:
152                    loader = self.pdb_reader
153                else:
154                    loader = None
155                # disable some entries depending on type of loaded file
156                # (according to documentation)
157                if self.ext.lower() in ['.sld', '.omf', '.pdb']:
158                    self.txtUpFracIn.setEnabled(False)
159                    self.txtUpFracOut.setEnabled(False)
160                    self.txtUpTheta.setEnabled(False)
161
162                if self.reader is not None and self.reader.isrunning():
163                    self.reader.stop()
164                self.cmdLoad.setEnabled(False)
165                self.cmdLoad.setText('Loading...')
166                self.communicator.statusBarUpdateSignal.emit(
167                    "Loading File {}".format(os.path.basename(
168                        str(self.datafile))))
169                self.reader = GenReader(path=str(self.datafile), loader=loader,
170                                        completefn=self.complete_loading,
171                                        updatefn=self.load_update)
172                self.reader.queue()
173        except (RuntimeError, IOError):
174            log_msg = "Generic SAS Calculator: %s" % sys.exc_value
175            logging.info(log_msg)
176            raise
177        return
178
179    def load_update(self):
180        """ Legacy function used in GenRead """
181        if self.reader.isrunning():
182            status_type = "progress"
183        else:
184            status_type = "stop"
185        logging.info(status_type)
186
187    def complete_loading(self, data=None):
188        """ Function used in GenRead"""
189        self.cbShape.setEnabled(False)
190        try:
191            is_pdbdata = False
192            self.txtData.setText(os.path.basename(str(self.datafile)))
193            self.is_avg = False
194            if self.ext in self.omf_reader.ext:
195                gen = sas_gen.OMF2SLD()
196                gen.set_data(data)
197                self.sld_data = gen.get_magsld()
198                self.check_units()
199            elif self.ext in self.sld_reader.ext:
200                self.sld_data = data
201            elif self.ext in self.pdb_reader.ext:
202                self.sld_data = data
203                is_pdbdata = True
204            # Display combobox of orientation only for pdb data
205            self.cbOptionsCalc.setVisible(is_pdbdata)
206            self.update_gui()
207        except IOError:
208            log_msg = "Loading Error: " \
209                      "This file format is not supported for GenSAS."
210            logging.info(log_msg)
211            raise
212        except ValueError:
213            log_msg = "Could not find any data"
214            logging.info(log_msg)
215            raise
216        logging.info("Load Complete")
217        self.cmdLoad.setEnabled(True)
218        self.cmdLoad.setText('Load')
219        self.trigger_plot_3d.emit()
220
221    def check_units(self):
222        """
223        Check if the units from the OMF file correspond to the default ones
224        displayed on the interface.
225        If not, modify the GUI with the correct unit
226        """
227        #  TODO: adopt the convention of font and symbol for the updated values
228        if sas_gen.OMFData().valueunit != 'A^(-2)':
229            value_unit = sas_gen.OMFData().valueunit
230            self.lbl_unitMx.setText(value_unit)
231            self.lbl_unitMy.setText(value_unit)
232            self.lbl_unitMz.setText(value_unit)
233            self.lbl_unitNucl.setText(value_unit)
234        if sas_gen.OMFData().meshunit != 'A':
235            mesh_unit = sas_gen.OMFData().meshunit
236            self.lbl_unitx.setText(mesh_unit)
237            self.lbl_unity.setText(mesh_unit)
238            self.lbl_unitz.setText(mesh_unit)
239            self.lbl_unitVolume.setText(mesh_unit+"^3")
240
241    def check_value(self):
242        """Check range of text edits for QMax and Number of Qbins """
243        text_edit = self.sender()
244        text_edit.setStyleSheet(
245            QtCore.QString.fromUtf8('background-color: rgb(255, 255, 255);'))
246        if text_edit.text():
247            value = float(str(text_edit.text()))
248            if text_edit == self.txtQxMax:
249                if value <= 0 or value > 1000:
250                    text_edit.setStyleSheet(QtCore.QString.fromUtf8(
251                        'background-color: rgb(255, 182, 193);'))
252                else:
253                    text_edit.setStyleSheet(QtCore.QString.fromUtf8(
254                        'background-color: rgb(255, 255, 255);'))
255            elif text_edit == self.txtNoQBins:
256                if value < 2 or value > 1000:
257                    self.txtNoQBins.setStyleSheet(QtCore.QString.fromUtf8(
258                        'background-color: rgb(255, 182, 193);'))
259                else:
260                    self.txtNoQBins.setStyleSheet(QtCore.QString.fromUtf8(
261                        'background-color: rgb(255, 255, 255);'))
262
263    def update_gui(self):
264        """ Update the interface with values from loaded data """
265        self.model.set_is_avg(self.is_avg)
266        self.model.set_sld_data(self.sld_data)
267        self.model.params['total_volume'] = len(self.sld_data.sld_n)*self.sld_data.vol_pix[0]
268
269        # add condition for activation of save button
270        self.cmdSave.setEnabled(True)
271
272        # activation of 3D plots' buttons (with and without arrows)
273        self.cmdDraw.setEnabled(self.sld_data is not None)
274        self.cmdDrawpoints.setEnabled(self.sld_data is not None)
275
276        self.txtScale.setText(str(self.model.params['scale']))
277        self.txtBackground.setText(str(self.model.params['background']))
278        self.txtSolventSLD.setText(str(self.model.params['solvent_SLD']))
279
280        # Volume to write to interface: npts x volume of first pixel
281        self.txtTotalVolume.setText(str(len(self.sld_data.sld_n)*self.sld_data.vol_pix[0]))
282
283        self.txtUpFracIn.setText(str(self.model.params['Up_frac_in']))
284        self.txtUpFracOut.setText(str(self.model.params['Up_frac_out']))
285        self.txtUpTheta.setText(str(self.model.params['Up_theta']))
286
287        self.txtNoPixels.setText(str(len(self.sld_data.sld_n)))
288        self.txtNoPixels.setEnabled(False)
289
290        list_parameters = ['sld_mx', 'sld_my', 'sld_mz', 'sld_n', 'xnodes',
291                           'ynodes', 'znodes', 'xstepsize', 'ystepsize',
292                           'zstepsize']
293        list_gui_button = [self.txtMx, self.txtMy, self.txtMz, self.txtNucl,
294                           self.txtXnodes, self.txtYnodes, self.txtZnodes,
295                           self.txtXstepsize, self.txtYstepsize,
296                           self.txtZstepsize]
297
298        # Fill right hand side of GUI
299        for indx, item in enumerate(list_parameters):
300            if getattr(self.sld_data, item) is None:
301                list_gui_button[indx].setText('NaN')
302            else:
303                value = getattr(self.sld_data, item)
304                if isinstance(value, numpy.ndarray):
305                    item_for_gui = str(GuiUtils.formatNumber(numpy.average(value), True))
306                else:
307                    item_for_gui = str(GuiUtils.formatNumber(value, True))
308                list_gui_button[indx].setText(item_for_gui)
309
310        # Enable / disable editing of right hand side of GUI
311        for indx, item in enumerate(list_parameters):
312            if indx < 4:
313                # this condition only applies to Mx,y,z and Nucl
314                value = getattr(self.sld_data, item)
315                enable = self.sld_data.pix_type == 'pixel' \
316                         and numpy.min(value) == numpy.max(value)
317            else:
318                enable = not self.sld_data.is_data
319            list_gui_button[indx].setEnabled(enable)
320
321    def write_new_values_from_gui(self):
322        """
323        update parameters using modified inputs from GUI
324        used before computing
325        """
326        if self.txtScale.isModified():
327            self.model.params['scale'] = float(self.txtScale.text())
328
329        if self.txtBackground.isModified():
330            self.model.params['background'] = float(self.txtBackground.text())
331
332        if self.txtSolventSLD.isModified():
333            self.model.params['solvent_SLD'] = float(self.txtSolventSLD.text())
334
335        # Different condition for total volume to get correct volume after
336        # applying set_sld_data in compute
337        if self.txtTotalVolume.isModified() \
338                or self.model.params['total_volume'] != float(self.txtTotalVolume.text()):
339            self.model.params['total_volume'] = float(self.txtTotalVolume.text())
340
341        if self.txtUpFracIn.isModified():
342            self.model.params['Up_frac_in'] = float(self.txtUpFracIn.text())
343
344        if self.txtUpFracOut.isModified():
345            self.model.params['Up_frac_out'] = float(self.txtUpFracOut.text())
346
347        if self.txtUpTheta.isModified():
348            self.model.params['Up_theta'] = float(self.txtUpTheta.text())
349
350        if self.txtMx.isModified():
351            self.sld_data.sld_mx = float(self.txtMx.text())*\
352                                   numpy.ones(len(self.sld_data.sld_mx))
353
354        if self.txtMy.isModified():
355            self.sld_data.sld_my = float(self.txtMy.text())*\
356                                   numpy.ones(len(self.sld_data.sld_my))
357
358        if self.txtMz.isModified():
359            self.sld_data.sld_mz = float(self.txtMz.text())*\
360                                   numpy.ones(len(self.sld_data.sld_mz))
361
362        if self.txtNucl.isModified():
363            self.sld_data.sld_n = float(self.txtNucl.text())*\
364                                  numpy.ones(len(self.sld_data.sld_n))
365
366        if self.txtXnodes.isModified():
367            self.sld_data.xnodes = int(self.txtXnodes.text())
368
369        if self.txtYnodes.isModified():
370            self.sld_data.ynodes = int(self.txtYnodes.text())
371
372        if self.txtZnodes.isModified():
373            self.sld_data.znodes = int(self.txtZnodes.text())
374
375        if self.txtXstepsize.isModified():
376            self.sld_data.xstepsize = float(self.txtXstepsize.text())
377
378        if self.txtYstepsize.isModified():
379            self.sld_data.ystepsize = float(self.txtYstepsize.text())
380
381        if self.txtZstepsize.isModified():
382            self.sld_data.zstepsize = float(self.txtZstepsize.text())
383
384        if self.cbOptionsCalc.isVisible():
385            self.is_avg = (self.cbOptionsCalc.currentIndex() == 1)
386
387    def onHelp(self):
388        """
389        Bring up the Generic Scattering calculator Documentation whenever
390        the HELP button is clicked.
391        Calls Documentation Window with the path of the location within the
392        documentation tree (after /doc/ ....".
393        """
394        try:
395            location = GuiUtils.HELP_DIRECTORY_LOCATION + \
396                       "/user/sasgui/perspectives/calculator/sas_calculator_help.html"
397            self.manager._helpView.load(QtCore.QUrl(location))
398            self.manager._helpView.show()
399        except AttributeError:
400            # No manager defined - testing and standalone runs
401            pass
402
403    def onReset(self):
404        """ Reset the inputs of textEdit to default values """
405        try:
406            # reset values in textedits
407            self.txtUpFracIn.setText("1.0")
408            self.txtUpFracOut.setText("1.0")
409            self.txtUpTheta.setText("0.0")
410            self.txtBackground.setText("0.0")
411            self.txtScale.setText("1.0")
412            self.txtSolventSLD.setText("0.0")
413            self.txtTotalVolume.setText("216000.0")
414            self.txtNoQBins.setText("50")
415            self.txtQxMax.setText("0.3")
416            self.txtNoPixels.setText("1000")
417            self.txtMx.setText("0")
418            self.txtMy.setText("0")
419            self.txtMz.setText("0")
420            self.txtNucl.setText("6.97e-06")
421            self.txtXnodes.setText("10")
422            self.txtYnodes.setText("10")
423            self.txtZnodes.setText("10")
424            self.txtXstepsize.setText("6")
425            self.txtYstepsize.setText("6")
426            self.txtZstepsize.setText("6")
427            # reset Load button and textedit
428            self.txtData.setText('Default SLD Profile')
429            self.cmdLoad.setEnabled(True)
430            self.cmdLoad.setText('Load')
431            # reset option for calculation
432            self.cbOptionsCalc.setCurrentIndex(0)
433            self.cbOptionsCalc.setVisible(False)
434            # reset shape button
435            self.cbShape.setCurrentIndex(0)
436            self.cbShape.setEnabled(True)
437            # reset compute button
438            self.cmdCompute.setText('Compute')
439            self.cmdCompute.setEnabled(True)
440            # TODO reload default data set
441            self._create_default_sld_data()
442
443        finally:
444            pass
445
446    def _create_default_2d_data(self):
447        """
448        Copied from previous version
449        Create 2D data by default
450        :warning: This data is never plotted.
451        """
452        self.qmax_x = float(self.txtQxMax.text())
453        self.npts_x = int(self.txtNoQBins.text())
454        self.data = Data2D()
455        self.data.is_data = False
456        # # Default values
457        self.data.detector.append(Detector())
458        index = len(self.data.detector) - 1
459        self.data.detector[index].distance = 8000  # mm
460        self.data.source.wavelength = 6  # A
461        self.data.detector[index].pixel_size.x = 5  # mm
462        self.data.detector[index].pixel_size.y = 5  # mm
463        self.data.detector[index].beam_center.x = self.qmax_x
464        self.data.detector[index].beam_center.y = self.qmax_x
465        xmax = self.qmax_x
466        xmin = -xmax
467        ymax = self.qmax_x
468        ymin = -ymax
469        qstep = self.npts_x
470
471        x = numpy.linspace(start=xmin, stop=xmax, num=qstep, endpoint=True)
472        y = numpy.linspace(start=ymin, stop=ymax, num=qstep, endpoint=True)
473        # use data info instead
474        new_x = numpy.tile(x, (len(y), 1))
475        new_y = numpy.tile(y, (len(x), 1))
476        new_y = new_y.swapaxes(0, 1)
477        # all data require now in 1d array
478        qx_data = new_x.flatten()
479        qy_data = new_y.flatten()
480        q_data = numpy.sqrt(qx_data * qx_data + qy_data * qy_data)
481        # set all True (standing for unmasked) as default
482        mask = numpy.ones(len(qx_data), dtype=bool)
483        self.data.source = Source()
484        self.data.data = numpy.ones(len(mask))
485        self.data.err_data = numpy.ones(len(mask))
486        self.data.qx_data = qx_data
487        self.data.qy_data = qy_data
488        self.data.q_data = q_data
489        self.data.mask = mask
490        # store x and y bin centers in q space
491        self.data.x_bins = x
492        self.data.y_bins = y
493        # max and min taking account of the bin sizes
494        self.data.xmin = xmin
495        self.data.xmax = xmax
496        self.data.ymin = ymin
497        self.data.ymax = ymax
498
499    def _create_default_sld_data(self):
500        """
501        Copied from previous version
502        Making default sld-data
503        """
504        sld_n_default = 6.97e-06  # what is this number??
505        omfdata = sas_gen.OMFData()
506        omf2sld = sas_gen.OMF2SLD()
507        omf2sld.set_data(omfdata, self.default_shape)
508        self.sld_data = omf2sld.output
509        self.sld_data.is_data = False
510        self.sld_data.filename = "Default SLD Profile"
511        self.sld_data.set_sldn(sld_n_default)
512
513    def _create_default_1d_data(self):
514        """
515        Copied from previous version
516        Create 1D data by default
517        :warning: This data is never plotted.
518                    residuals.x = data_copy.x[index]
519            residuals.dy = numpy.ones(len(residuals.y))
520            residuals.dx = None
521            residuals.dxl = None
522            residuals.dxw = None
523        """
524        self.qmax_x = float(self.txtQxMax.text())
525        self.npts_x = int(self.txtNoQBins.text())
526        #  Default values
527        xmax = self.qmax_x
528        xmin = self.qmax_x * _Q1D_MIN
529        qstep = self.npts_x
530        x = numpy.linspace(start=xmin, stop=xmax, num=qstep, endpoint=True)
531        # store x and y bin centers in q space
532        y = numpy.ones(len(x))
533        dy = numpy.zeros(len(x))
534        dx = numpy.zeros(len(x))
535        self.data = Data1D(x=x, y=y)
536        self.data.dx = dx
537        self.data.dy = dy
538
539    def onCompute(self):
540        """
541        Copied from previous version
542        Execute the computation of I(qx, qy)
543        """
544        # Set default data when nothing loaded yet
545        if self.sld_data is None:
546            self._create_default_sld_data()
547        try:
548            self.model.set_sld_data(self.sld_data)
549            self.write_new_values_from_gui()
550            if self.is_avg or self.is_avg is None:
551                self._create_default_1d_data()
552                i_out = numpy.zeros(len(self.data.y))
553                inputs = [self.data.x, [], i_out]
554            else:
555                self._create_default_2d_data()
556                i_out = numpy.zeros(len(self.data.data))
557                inputs = [self.data.qx_data, self.data.qy_data, i_out]
558            logging.info("Computation is in progress...")
559            self.cmdCompute.setText('Wait...')
560            self.cmdCompute.setEnabled(False)
561            d = threads.deferToThread(self.complete, inputs, self._update)
562            # Add deferred callback for call return
563            d.addCallback(self.plot_1_2d)
564        except:
565            log_msg = "{}. stop".format(sys.exc_value)
566            logging.info(log_msg)
567        return
568
569    def _update(self, value):
570        """
571        Copied from previous version
572        """
573        pass
574
575    def complete(self, input, update=None):
576        """
577        Gen compute complete function
578        :Param input: input list [qx_data, qy_data, i_out]
579        """
580        out = numpy.empty(0)
581        for ind in range(len(input[0])):
582            if self.is_avg:
583                if ind % 1 == 0 and update is not None:
584                    # update()
585                    percentage = int(100.0 * float(ind) / len(input[0]))
586                    update(percentage)
587                    time.sleep(0.001)  # 0.1
588                inputi = [input[0][ind:ind + 1], [], input[2][ind:ind + 1]]
589                outi = self.model.run(inputi)
590                out = numpy.append(out, outi)
591            else:
592                if ind % 50 == 0 and update is not None:
593                    percentage = int(100.0 * float(ind) / len(input[0]))
594                    update(percentage)
595                    time.sleep(0.001)
596                inputi = [input[0][ind:ind + 1], input[1][ind:ind + 1],
597                          input[2][ind:ind + 1]]
598                outi = self.model.runXY(inputi)
599                out = numpy.append(out, outi)
600        self.data_to_plot = out
601        logging.info('Gen computation completed.')
602        self.cmdCompute.setText('Compute')
603        self.cmdCompute.setEnabled(True)
604        return
605
606    def onSaveFile(self):
607        """Save data as .sld file"""
608        path = os.path.dirname(str(self.datafile))
609        default_name = os.path.join(path, 'sld_file')
610        kwargs = {
611            'parent': self,
612            'directory': default_name,
613            'filter': 'SLD file (*.sld)',
614            'options': QtGui.QFileDialog.DontUseNativeDialog}
615        # Query user for filename.
616        filename = str(QtGui.QFileDialog.getSaveFileName(**kwargs))
617        if filename:
618            try:
619                if os.path.splitext(filename)[1].lower() == '.sld':
620                    sas_gen.SLDReader().write(filename, self.sld_data)
621                else:
622                    sas_gen.SLDReader().write('.'.join((filename, 'sld')),
623                                              self.sld_data)
624            except:
625                raise
626
627    def plot3d(self, has_arrow=False):
628        """ Generate 3D plot in real space with or without arrows """
629        self.write_new_values_from_gui()
630        graph_title = " Graph {}: {} 3D SLD Profile".format(self.graph_num,
631                                                            self.file_name)
632        if has_arrow:
633            graph_title += ' - Magnetic Vector as Arrow'
634
635        plot3D = Plotter3D(self, graph_title)
636        plot3D.plot(self.sld_data, has_arrow=has_arrow)
637        plot3D.show()
638        self.graph_num += 1
639
640    def plot_1_2d(self, d):
641        """ Generate 1D or 2D plot, called in Compute"""
642        if self.is_avg or self.is_avg is None:
643            data = Data1D(x=self.data.x, y=self.data_to_plot)
644            data.title = "GenSAS {}  #{} 1D".format(self.file_name,
645                                                    int(self.graph_num))
646            data.xaxis('\\rm{Q_{x}}', '\AA^{-1}')
647            data.yaxis('\\rm{Intensity}', 'cm^{-1}')
648            plot1D = Plotter(self)
649            plot1D.plot(data)
650            plot1D.show()
651            self.graph_num += 1
652            # TODO
653            print 'TRANSFER OF DATA TO MAIN PANEL TO BE IMPLEMENTED'
654            return plot1D
655        else:
656            numpy.nan_to_num(self.data_to_plot)
657            data = Data2D(image=self.data_to_plot,
658                          qx_data=self.data.qx_data,
659                          qy_data=self.data.qy_data,
660                          q_data=self.data.q_data,
661                          xmin=self.data.xmin, xmax=self.data.ymax,
662                          ymin=self.data.ymin, ymax=self.data.ymax,
663                          err_image=self.data.err_data)
664            data.title = "GenSAS {}  #{} 2D".format(self.file_name,
665                                                    int(self.graph_num))
666            plot2D = Plotter2D(self)
667            plot2D.plot(data)
668            plot2D.show()
669            self.graph_num += 1
670            # TODO
671            print 'TRANSFER OF DATA TO MAIN PANEL TO BE IMPLEMENTED'
672            return plot2D
673
674
675class Plotter3DWidget(PlotterBase):
676    """
677    3D Plot widget for use with a QDialog
678    """
679    def __init__(self, parent=None, manager=None):
680        super(Plotter3DWidget, self).__init__(parent,  manager=manager)
681
682    @property
683    def data(self):
684        return self._data
685
686    @data.setter
687    def data(self, data=None):
688        """ data setter """
689        self._data = data
690
691    def plot(self, data=None, has_arrow=False):
692        """
693        Plot 3D self._data
694        """
695        self.data = data
696        assert(self._data)
697        # Prepare and show the plot
698        self.showPlot(data=self.data, has_arrow=has_arrow)
699
700    def showPlot(self, data, has_arrow=False):
701        """
702        Render and show the current data
703        """
704        # If we don't have any data, skip.
705        if data is None:
706            return
707        color_dic = {'H': 'blue', 'D': 'purple', 'N': 'orange',
708                     'O': 'red', 'C': 'green', 'P': 'cyan', 'Other': 'k'}
709        marker = ','
710        m_size = 2
711
712        pos_x = data.pos_x
713        pos_y = data.pos_y
714        pos_z = data.pos_z
715        sld_mx = data.sld_mx
716        sld_my = data.sld_my
717        sld_mz = data.sld_mz
718        pix_symbol = data.pix_symbol
719        sld_tot = numpy.fabs(sld_mx) + numpy.fabs(sld_my) + \
720                  numpy.fabs(sld_mz) + numpy.fabs(data.sld_n)
721        is_nonzero = sld_tot > 0.0
722        is_zero = sld_tot == 0.0
723
724        if data.pix_type == 'atom':
725            marker = 'o'
726            m_size = 3.5
727
728        self.figure.clear()
729        self.figure.subplots_adjust(left=0.1, right=.8, bottom=.1)
730        ax = Axes3D(self.figure)
731        ax.set_xlabel('x ($\A{}$)'.format(data.pos_unit))
732        ax.set_ylabel('z ($\A{}$)'.format(data.pos_unit))
733        ax.set_zlabel('y ($\A{}$)'.format(data.pos_unit))
734
735        # I. Plot null points
736        if is_zero.any():
737            im = ax.plot(pos_x[is_zero], pos_z[is_zero], pos_y[is_zero],
738                           marker, c="y", alpha=0.5, markeredgecolor='y',
739                           markersize=m_size)
740            pos_x = pos_x[is_nonzero]
741            pos_y = pos_y[is_nonzero]
742            pos_z = pos_z[is_nonzero]
743            sld_mx = sld_mx[is_nonzero]
744            sld_my = sld_my[is_nonzero]
745            sld_mz = sld_mz[is_nonzero]
746            pix_symbol = data.pix_symbol[is_nonzero]
747        # II. Plot selective points in color
748        other_color = numpy.ones(len(pix_symbol), dtype='bool')
749        for key in color_dic.keys():
750            chosen_color = pix_symbol == key
751            if numpy.any(chosen_color):
752                other_color = other_color & (chosen_color!=True)
753                color = color_dic[key]
754                im = ax.plot(pos_x[chosen_color], pos_z[chosen_color],
755                         pos_y[chosen_color], marker, c=color, alpha=0.5,
756                         markeredgecolor=color, markersize=m_size, label=key)
757        # III. Plot All others
758        if numpy.any(other_color):
759            a_name = ''
760            if data.pix_type == 'atom':
761                # Get atom names not in the list
762                a_names = [symb for symb in pix_symbol \
763                           if symb not in color_dic.keys()]
764                a_name = a_names[0]
765                for name in a_names:
766                    new_name = ", " + name
767                    if a_name.count(name) == 0:
768                        a_name += new_name
769            # plot in black
770            im = ax.plot(pos_x[other_color], pos_z[other_color],
771                         pos_y[other_color], marker, c="k", alpha=0.5,
772                         markeredgecolor="k", markersize=m_size, label=a_name)
773        if data.pix_type == 'atom':
774            ax.legend(loc='upper left', prop={'size': 10})
775        # IV. Draws atomic bond with grey lines if any
776        if data.has_conect:
777            for ind in range(len(data.line_x)):
778                im = ax.plot(data.line_x[ind], data.line_z[ind],
779                             data.line_y[ind], '-', lw=0.6, c="grey",
780                             alpha=0.3)
781        # V. Draws magnetic vectors
782        if has_arrow and len(pos_x) > 0:
783            def _draw_arrow(input=None, update=None):
784                """
785                draw magnetic vectors w/arrow
786                """
787                max_mx = max(numpy.fabs(sld_mx))
788                max_my = max(numpy.fabs(sld_my))
789                max_mz = max(numpy.fabs(sld_mz))
790                max_m = max(max_mx, max_my, max_mz)
791                try:
792                    max_step = max(data.xstepsize, data.ystepsize, data.zstepsize)
793                except:
794                    max_step = 0
795                if max_step <= 0:
796                    max_step = 5
797                try:
798                    if max_m != 0:
799                        unit_x2 = sld_mx / max_m
800                        unit_y2 = sld_my / max_m
801                        unit_z2 = sld_mz / max_m
802                        # 0.8 is for avoiding the color becomes white=(1,1,1))
803                        color_x = numpy.fabs(unit_x2 * 0.8)
804                        color_y = numpy.fabs(unit_y2 * 0.8)
805                        color_z = numpy.fabs(unit_z2 * 0.8)
806                        x2 = pos_x + unit_x2 * max_step
807                        y2 = pos_y + unit_y2 * max_step
808                        z2 = pos_z + unit_z2 * max_step
809                        x_arrow = numpy.column_stack((pos_x, x2))
810                        y_arrow = numpy.column_stack((pos_y, y2))
811                        z_arrow = numpy.column_stack((pos_z, z2))
812                        colors = numpy.column_stack((color_x, color_y, color_z))
813                        arrows = Arrow3D(self.figure, x_arrow, z_arrow, y_arrow,
814                                        colors, mutation_scale=10, lw=1,
815                                        arrowstyle="->", alpha=0.5)
816                        ax.add_artist(arrows)
817                except:
818                    pass
819                log_msg = "Arrow Drawing completed.\n"
820                logging.info(log_msg)
821            log_msg = "Arrow Drawing is in progress..."
822            logging.info(log_msg)
823
824            # Defer the drawing of arrows to another thread
825            d = threads.deferToThread(_draw_arrow, ax)
826
827        self.figure.canvas.resizing = False
828        self.figure.canvas.draw()
829
830
831class Plotter3D(QtGui.QDialog, Plotter3DWidget):
832    def __init__(self, parent=None, graph_title=''):
833        self.graph_title = graph_title
834        QtGui.QDialog.__init__(self)
835        Plotter3DWidget.__init__(self, manager=parent)
836        self.setWindowTitle(self.graph_title)
837
Note: See TracBrowser for help on using the repository browser.