source: sasview/src/sas/qtgui/Calculators/GenericScatteringCalculator.py @ 57ad773

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalc
Last change on this file since 57ad773 was fef38e8, checked in by Piotr Rozyczko <rozyczko@…>, 7 years ago

Startup time improvements - hiding expensive imports and such

  • Property mode set to 100644
File size: 33.5 KB
Line 
1import sys
2import os
3import numpy
4import logging
5import time
6
7from PyQt4 import QtGui
8from PyQt4 import QtCore
9from twisted.internet import threads
10
11import sas.qtgui.Utilities.GuiUtils as GuiUtils
12from sas.qtgui.Utilities.GenericReader import GenReader
13from sas.sascalc.dataloader.data_info import Detector
14from sas.sascalc.dataloader.data_info import Source
15from sas.sascalc.calculator import sas_gen
16from sas.qtgui.Plotting.PlotterBase import PlotterBase
17from sas.qtgui.Plotting.Plotter2D import Plotter2D
18from sas.qtgui.Plotting.Plotter import Plotter
19
20from sas.qtgui.Plotting.PlotterData import Data1D
21from sas.qtgui.Plotting.PlotterData import Data2D
22
23# Local UI
24from UI.GenericScatteringCalculator import Ui_GenericScatteringCalculator
25
26_Q1D_MIN = 0.001
27
28
29class GenericScatteringCalculator(QtGui.QDialog, Ui_GenericScatteringCalculator):
30
31    trigger_plot_3d = QtCore.pyqtSignal()
32
33    def __init__(self, parent=None):
34        super(GenericScatteringCalculator, self).__init__()
35        self.setupUi(self)
36        self.manager = parent
37        self.communicator = self.manager.communicator()
38        self.model = sas_gen.GenSAS()
39        self.omf_reader = sas_gen.OMFReader()
40        self.sld_reader = sas_gen.SLDReader()
41        self.pdb_reader = sas_gen.PDBReader()
42        self.reader = None
43        self.sld_data = None
44
45        self.parameters = []
46        self.data = None
47        self.datafile = None
48        self.file_name = ''
49        self.ext = None
50        self.default_shape = str(self.cbShape.currentText())
51        self.is_avg = False
52        self.data_to_plot = None
53        self.graph_num = 1  # index for name of graph
54
55        # combox box
56        self.cbOptionsCalc.setVisible(False)
57
58        # push buttons
59        self.cmdClose.clicked.connect(self.accept)
60        self.cmdHelp.clicked.connect(self.onHelp)
61
62        self.cmdLoad.clicked.connect(self.loadFile)
63        self.cmdCompute.clicked.connect(self.onCompute)
64        self.cmdReset.clicked.connect(self.onReset)
65        self.cmdSave.clicked.connect(self.onSaveFile)
66
67        self.cmdDraw.clicked.connect(lambda: self.plot3d(has_arrow=True))
68        self.cmdDrawpoints.clicked.connect(lambda: self.plot3d(has_arrow=False))
69
70        # validators
71        # scale, volume and background must be positive
72        validat_regex_pos = QtCore.QRegExp('^[+]?([.]\d+|\d+([.]\d+)?)$')
73        self.txtScale.setValidator(QtGui.QRegExpValidator(validat_regex_pos,
74                                                          self.txtScale))
75        self.txtBackground.setValidator(QtGui.QRegExpValidator(
76            validat_regex_pos, self.txtBackground))
77        self.txtTotalVolume.setValidator(QtGui.QRegExpValidator(
78            validat_regex_pos, self.txtTotalVolume))
79
80        # fraction of spin up between 0 and 1
81        validat_regexbetween0_1 = QtCore.QRegExp('^(0(\.\d*)*|1(\.0+)?)$')
82        self.txtUpFracIn.setValidator(
83            QtGui.QRegExpValidator(validat_regexbetween0_1, self.txtUpFracIn))
84        self.txtUpFracOut.setValidator(
85            QtGui.QRegExpValidator(validat_regexbetween0_1, self.txtUpFracOut))
86
87        # 0 < Qmax <= 1000
88        validat_regex_q = QtCore.QRegExp('^1000$|^[+]?(\d{1,3}([.]\d+)?)$')
89        self.txtQxMax.setValidator(QtGui.QRegExpValidator(validat_regex_q,
90                                                          self.txtQxMax))
91        self.txtQxMax.textChanged.connect(self.check_value)
92
93        # 2 <= Qbin <= 1000
94        self.txtNoQBins.setValidator(QtGui.QRegExpValidator(validat_regex_q,
95                                                            self.txtNoQBins))
96        self.txtNoQBins.textChanged.connect(self.check_value)
97
98        # plots - 3D in real space
99        self.trigger_plot_3d.connect(lambda: self.plot3d(has_arrow=False))
100
101        # TODO the option Ellipsoid has not been implemented
102        self.cbShape.currentIndexChanged.connect(self.selectedshapechange)
103
104        # New font to display angstrom symbol
105        new_font = 'font-family: -apple-system, "Helvetica Neue", "Ubuntu";'
106        self.lblUnitSolventSLD.setStyleSheet(new_font)
107        self.lblUnitVolume.setStyleSheet(new_font)
108        self.lbl5.setStyleSheet(new_font)
109        self.lblUnitMx.setStyleSheet(new_font)
110        self.lblUnitMy.setStyleSheet(new_font)
111        self.lblUnitMz.setStyleSheet(new_font)
112        self.lblUnitNucl.setStyleSheet(new_font)
113        self.lblUnitx.setStyleSheet(new_font)
114        self.lblUnity.setStyleSheet(new_font)
115        self.lblUnitz.setStyleSheet(new_font)
116
117    def selectedshapechange(self):
118        """
119        TODO Temporary solution to display information about option 'Ellipsoid'
120        """
121        print "The option Ellipsoid has not been implemented yet."
122        self.communicator.statusBarUpdateSignal.emit(
123            "The option Ellipsoid has not been implemented yet.")
124
125    def loadFile(self):
126        """
127        Open menu to choose the datafile to load
128        Only extensions .SLD, .PDB, .OMF, .sld, .pdb, .omf
129        """
130        try:
131            self.datafile = QtGui.QFileDialog.getOpenFileName(
132                self, "Choose a file", "", "All Gen files (*.OMF *.omf) ;;"
133                                          "SLD files (*.SLD *.sld);;PDB files (*.pdb *.PDB);; "
134                                          "OMF files (*.OMF *.omf);; "
135                                          "All files (*.*)")
136            if self.datafile:
137                self.default_shape = str(self.cbShape.currentText())
138                self.file_name = os.path.basename(str(self.datafile))
139                self.ext = os.path.splitext(str(self.datafile))[1]
140                if self.ext in self.omf_reader.ext:
141                    loader = self.omf_reader
142                elif self.ext in self.sld_reader.ext:
143                    loader = self.sld_reader
144                elif self.ext in self.pdb_reader.ext:
145                    loader = self.pdb_reader
146                else:
147                    loader = None
148                # disable some entries depending on type of loaded file
149                # (according to documentation)
150                if self.ext.lower() in ['.sld', '.omf', '.pdb']:
151                    self.txtUpFracIn.setEnabled(False)
152                    self.txtUpFracOut.setEnabled(False)
153                    self.txtUpTheta.setEnabled(False)
154
155                if self.reader is not None and self.reader.isrunning():
156                    self.reader.stop()
157                self.cmdLoad.setEnabled(False)
158                self.cmdLoad.setText('Loading...')
159                self.communicator.statusBarUpdateSignal.emit(
160                    "Loading File {}".format(os.path.basename(
161                        str(self.datafile))))
162                self.reader = GenReader(path=str(self.datafile), loader=loader,
163                                        completefn=self.complete_loading,
164                                        updatefn=self.load_update)
165                self.reader.queue()
166        except (RuntimeError, IOError):
167            log_msg = "Generic SAS Calculator: %s" % sys.exc_value
168            logging.info(log_msg)
169            raise
170        return
171
172    def load_update(self):
173        """ Legacy function used in GenRead """
174        if self.reader.isrunning():
175            status_type = "progress"
176        else:
177            status_type = "stop"
178        logging.info(status_type)
179
180    def complete_loading(self, data=None):
181        """ Function used in GenRead"""
182        self.cbShape.setEnabled(False)
183        try:
184            is_pdbdata = False
185            self.txtData.setText(os.path.basename(str(self.datafile)))
186            self.is_avg = False
187            if self.ext in self.omf_reader.ext:
188                gen = sas_gen.OMF2SLD()
189                gen.set_data(data)
190                self.sld_data = gen.get_magsld()
191                self.check_units()
192            elif self.ext in self.sld_reader.ext:
193                self.sld_data = data
194            elif self.ext in self.pdb_reader.ext:
195                self.sld_data = data
196                is_pdbdata = True
197            # Display combobox of orientation only for pdb data
198            self.cbOptionsCalc.setVisible(is_pdbdata)
199            self.update_gui()
200        except IOError:
201            log_msg = "Loading Error: " \
202                      "This file format is not supported for GenSAS."
203            logging.info(log_msg)
204            raise
205        except ValueError:
206            log_msg = "Could not find any data"
207            logging.info(log_msg)
208            raise
209        logging.info("Load Complete")
210        self.cmdLoad.setEnabled(True)
211        self.cmdLoad.setText('Load')
212        self.trigger_plot_3d.emit()
213
214    def check_units(self):
215        """
216        Check if the units from the OMF file correspond to the default ones
217        displayed on the interface.
218        If not, modify the GUI with the correct unit
219        """
220        #  TODO: adopt the convention of font and symbol for the updated values
221        if sas_gen.OMFData().valueunit != 'A^(-2)':
222            value_unit = sas_gen.OMFData().valueunit
223            self.lbl_unitMx.setText(value_unit)
224            self.lbl_unitMy.setText(value_unit)
225            self.lbl_unitMz.setText(value_unit)
226            self.lbl_unitNucl.setText(value_unit)
227        if sas_gen.OMFData().meshunit != 'A':
228            mesh_unit = sas_gen.OMFData().meshunit
229            self.lbl_unitx.setText(mesh_unit)
230            self.lbl_unity.setText(mesh_unit)
231            self.lbl_unitz.setText(mesh_unit)
232            self.lbl_unitVolume.setText(mesh_unit+"^3")
233
234    def check_value(self):
235        """Check range of text edits for QMax and Number of Qbins """
236        text_edit = self.sender()
237        text_edit.setStyleSheet(
238            QtCore.QString.fromUtf8('background-color: rgb(255, 255, 255);'))
239        if text_edit.text():
240            value = float(str(text_edit.text()))
241            if text_edit == self.txtQxMax:
242                if value <= 0 or value > 1000:
243                    text_edit.setStyleSheet(QtCore.QString.fromUtf8(
244                        'background-color: rgb(255, 182, 193);'))
245                else:
246                    text_edit.setStyleSheet(QtCore.QString.fromUtf8(
247                        'background-color: rgb(255, 255, 255);'))
248            elif text_edit == self.txtNoQBins:
249                if value < 2 or value > 1000:
250                    self.txtNoQBins.setStyleSheet(QtCore.QString.fromUtf8(
251                        'background-color: rgb(255, 182, 193);'))
252                else:
253                    self.txtNoQBins.setStyleSheet(QtCore.QString.fromUtf8(
254                        'background-color: rgb(255, 255, 255);'))
255
256    def update_gui(self):
257        """ Update the interface with values from loaded data """
258        self.model.set_is_avg(self.is_avg)
259        self.model.set_sld_data(self.sld_data)
260        self.model.params['total_volume'] = len(self.sld_data.sld_n)*self.sld_data.vol_pix[0]
261
262        # add condition for activation of save button
263        self.cmdSave.setEnabled(True)
264
265        # activation of 3D plots' buttons (with and without arrows)
266        self.cmdDraw.setEnabled(self.sld_data is not None)
267        self.cmdDrawpoints.setEnabled(self.sld_data is not None)
268
269        self.txtScale.setText(str(self.model.params['scale']))
270        self.txtBackground.setText(str(self.model.params['background']))
271        self.txtSolventSLD.setText(str(self.model.params['solvent_SLD']))
272
273        # Volume to write to interface: npts x volume of first pixel
274        self.txtTotalVolume.setText(str(len(self.sld_data.sld_n)*self.sld_data.vol_pix[0]))
275
276        self.txtUpFracIn.setText(str(self.model.params['Up_frac_in']))
277        self.txtUpFracOut.setText(str(self.model.params['Up_frac_out']))
278        self.txtUpTheta.setText(str(self.model.params['Up_theta']))
279
280        self.txtNoPixels.setText(str(len(self.sld_data.sld_n)))
281        self.txtNoPixels.setEnabled(False)
282
283        list_parameters = ['sld_mx', 'sld_my', 'sld_mz', 'sld_n', 'xnodes',
284                           'ynodes', 'znodes', 'xstepsize', 'ystepsize',
285                           'zstepsize']
286        list_gui_button = [self.txtMx, self.txtMy, self.txtMz, self.txtNucl,
287                           self.txtXnodes, self.txtYnodes, self.txtZnodes,
288                           self.txtXstepsize, self.txtYstepsize,
289                           self.txtZstepsize]
290
291        # Fill right hand side of GUI
292        for indx, item in enumerate(list_parameters):
293            if getattr(self.sld_data, item) is None:
294                list_gui_button[indx].setText('NaN')
295            else:
296                value = getattr(self.sld_data, item)
297                if isinstance(value, numpy.ndarray):
298                    item_for_gui = str(GuiUtils.formatNumber(numpy.average(value), True))
299                else:
300                    item_for_gui = str(GuiUtils.formatNumber(value, True))
301                list_gui_button[indx].setText(item_for_gui)
302
303        # Enable / disable editing of right hand side of GUI
304        for indx, item in enumerate(list_parameters):
305            if indx < 4:
306                # this condition only applies to Mx,y,z and Nucl
307                value = getattr(self.sld_data, item)
308                enable = self.sld_data.pix_type == 'pixel' \
309                         and numpy.min(value) == numpy.max(value)
310            else:
311                enable = not self.sld_data.is_data
312            list_gui_button[indx].setEnabled(enable)
313
314    def write_new_values_from_gui(self):
315        """
316        update parameters using modified inputs from GUI
317        used before computing
318        """
319        if self.txtScale.isModified():
320            self.model.params['scale'] = float(self.txtScale.text())
321
322        if self.txtBackground.isModified():
323            self.model.params['background'] = float(self.txtBackground.text())
324
325        if self.txtSolventSLD.isModified():
326            self.model.params['solvent_SLD'] = float(self.txtSolventSLD.text())
327
328        # Different condition for total volume to get correct volume after
329        # applying set_sld_data in compute
330        if self.txtTotalVolume.isModified() \
331                or self.model.params['total_volume'] != float(self.txtTotalVolume.text()):
332            self.model.params['total_volume'] = float(self.txtTotalVolume.text())
333
334        if self.txtUpFracIn.isModified():
335            self.model.params['Up_frac_in'] = float(self.txtUpFracIn.text())
336
337        if self.txtUpFracOut.isModified():
338            self.model.params['Up_frac_out'] = float(self.txtUpFracOut.text())
339
340        if self.txtUpTheta.isModified():
341            self.model.params['Up_theta'] = float(self.txtUpTheta.text())
342
343        if self.txtMx.isModified():
344            self.sld_data.sld_mx = float(self.txtMx.text())*\
345                                   numpy.ones(len(self.sld_data.sld_mx))
346
347        if self.txtMy.isModified():
348            self.sld_data.sld_my = float(self.txtMy.text())*\
349                                   numpy.ones(len(self.sld_data.sld_my))
350
351        if self.txtMz.isModified():
352            self.sld_data.sld_mz = float(self.txtMz.text())*\
353                                   numpy.ones(len(self.sld_data.sld_mz))
354
355        if self.txtNucl.isModified():
356            self.sld_data.sld_n = float(self.txtNucl.text())*\
357                                  numpy.ones(len(self.sld_data.sld_n))
358
359        if self.txtXnodes.isModified():
360            self.sld_data.xnodes = int(self.txtXnodes.text())
361
362        if self.txtYnodes.isModified():
363            self.sld_data.ynodes = int(self.txtYnodes.text())
364
365        if self.txtZnodes.isModified():
366            self.sld_data.znodes = int(self.txtZnodes.text())
367
368        if self.txtXstepsize.isModified():
369            self.sld_data.xstepsize = float(self.txtXstepsize.text())
370
371        if self.txtYstepsize.isModified():
372            self.sld_data.ystepsize = float(self.txtYstepsize.text())
373
374        if self.txtZstepsize.isModified():
375            self.sld_data.zstepsize = float(self.txtZstepsize.text())
376
377        if self.cbOptionsCalc.isVisible():
378            self.is_avg = (self.cbOptionsCalc.currentIndex() == 1)
379
380    def onHelp(self):
381        """
382        Bring up the Generic Scattering calculator Documentation whenever
383        the HELP button is clicked.
384        Calls Documentation Window with the path of the location within the
385        documentation tree (after /doc/ ....".
386        """
387        try:
388            location = GuiUtils.HELP_DIRECTORY_LOCATION + \
389                       "/user/sasgui/perspectives/calculator/sas_calculator_help.html"
390            self.manager._helpView.load(QtCore.QUrl(location))
391            self.manager._helpView.show()
392        except AttributeError:
393            # No manager defined - testing and standalone runs
394            pass
395
396    def onReset(self):
397        """ Reset the inputs of textEdit to default values """
398        try:
399            # reset values in textedits
400            self.txtUpFracIn.setText("1.0")
401            self.txtUpFracOut.setText("1.0")
402            self.txtUpTheta.setText("0.0")
403            self.txtBackground.setText("0.0")
404            self.txtScale.setText("1.0")
405            self.txtSolventSLD.setText("0.0")
406            self.txtTotalVolume.setText("216000.0")
407            self.txtNoQBins.setText("50")
408            self.txtQxMax.setText("0.3")
409            self.txtNoPixels.setText("1000")
410            self.txtMx.setText("0")
411            self.txtMy.setText("0")
412            self.txtMz.setText("0")
413            self.txtNucl.setText("6.97e-06")
414            self.txtXnodes.setText("10")
415            self.txtYnodes.setText("10")
416            self.txtZnodes.setText("10")
417            self.txtXstepsize.setText("6")
418            self.txtYstepsize.setText("6")
419            self.txtZstepsize.setText("6")
420            # reset Load button and textedit
421            self.txtData.setText('Default SLD Profile')
422            self.cmdLoad.setEnabled(True)
423            self.cmdLoad.setText('Load')
424            # reset option for calculation
425            self.cbOptionsCalc.setCurrentIndex(0)
426            self.cbOptionsCalc.setVisible(False)
427            # reset shape button
428            self.cbShape.setCurrentIndex(0)
429            self.cbShape.setEnabled(True)
430            # reset compute button
431            self.cmdCompute.setText('Compute')
432            self.cmdCompute.setEnabled(True)
433            # TODO reload default data set
434            self._create_default_sld_data()
435
436        finally:
437            pass
438
439    def _create_default_2d_data(self):
440        """
441        Copied from previous version
442        Create 2D data by default
443        :warning: This data is never plotted.
444        """
445        self.qmax_x = float(self.txtQxMax.text())
446        self.npts_x = int(self.txtNoQBins.text())
447        self.data = Data2D()
448        self.data.is_data = False
449        # # Default values
450        self.data.detector.append(Detector())
451        index = len(self.data.detector) - 1
452        self.data.detector[index].distance = 8000  # mm
453        self.data.source.wavelength = 6  # A
454        self.data.detector[index].pixel_size.x = 5  # mm
455        self.data.detector[index].pixel_size.y = 5  # mm
456        self.data.detector[index].beam_center.x = self.qmax_x
457        self.data.detector[index].beam_center.y = self.qmax_x
458        xmax = self.qmax_x
459        xmin = -xmax
460        ymax = self.qmax_x
461        ymin = -ymax
462        qstep = self.npts_x
463
464        x = numpy.linspace(start=xmin, stop=xmax, num=qstep, endpoint=True)
465        y = numpy.linspace(start=ymin, stop=ymax, num=qstep, endpoint=True)
466        # use data info instead
467        new_x = numpy.tile(x, (len(y), 1))
468        new_y = numpy.tile(y, (len(x), 1))
469        new_y = new_y.swapaxes(0, 1)
470        # all data require now in 1d array
471        qx_data = new_x.flatten()
472        qy_data = new_y.flatten()
473        q_data = numpy.sqrt(qx_data * qx_data + qy_data * qy_data)
474        # set all True (standing for unmasked) as default
475        mask = numpy.ones(len(qx_data), dtype=bool)
476        self.data.source = Source()
477        self.data.data = numpy.ones(len(mask))
478        self.data.err_data = numpy.ones(len(mask))
479        self.data.qx_data = qx_data
480        self.data.qy_data = qy_data
481        self.data.q_data = q_data
482        self.data.mask = mask
483        # store x and y bin centers in q space
484        self.data.x_bins = x
485        self.data.y_bins = y
486        # max and min taking account of the bin sizes
487        self.data.xmin = xmin
488        self.data.xmax = xmax
489        self.data.ymin = ymin
490        self.data.ymax = ymax
491
492    def _create_default_sld_data(self):
493        """
494        Copied from previous version
495        Making default sld-data
496        """
497        sld_n_default = 6.97e-06  # what is this number??
498        omfdata = sas_gen.OMFData()
499        omf2sld = sas_gen.OMF2SLD()
500        omf2sld.set_data(omfdata, self.default_shape)
501        self.sld_data = omf2sld.output
502        self.sld_data.is_data = False
503        self.sld_data.filename = "Default SLD Profile"
504        self.sld_data.set_sldn(sld_n_default)
505
506    def _create_default_1d_data(self):
507        """
508        Copied from previous version
509        Create 1D data by default
510        :warning: This data is never plotted.
511                    residuals.x = data_copy.x[index]
512            residuals.dy = numpy.ones(len(residuals.y))
513            residuals.dx = None
514            residuals.dxl = None
515            residuals.dxw = None
516        """
517        self.qmax_x = float(self.txtQxMax.text())
518        self.npts_x = int(self.txtNoQBins.text())
519        #  Default values
520        xmax = self.qmax_x
521        xmin = self.qmax_x * _Q1D_MIN
522        qstep = self.npts_x
523        x = numpy.linspace(start=xmin, stop=xmax, num=qstep, endpoint=True)
524        # store x and y bin centers in q space
525        y = numpy.ones(len(x))
526        dy = numpy.zeros(len(x))
527        dx = numpy.zeros(len(x))
528        self.data = Data1D(x=x, y=y)
529        self.data.dx = dx
530        self.data.dy = dy
531
532    def onCompute(self):
533        """
534        Copied from previous version
535        Execute the computation of I(qx, qy)
536        """
537        # Set default data when nothing loaded yet
538        if self.sld_data is None:
539            self._create_default_sld_data()
540        try:
541            self.model.set_sld_data(self.sld_data)
542            self.write_new_values_from_gui()
543            if self.is_avg or self.is_avg is None:
544                self._create_default_1d_data()
545                i_out = numpy.zeros(len(self.data.y))
546                inputs = [self.data.x, [], i_out]
547            else:
548                self._create_default_2d_data()
549                i_out = numpy.zeros(len(self.data.data))
550                inputs = [self.data.qx_data, self.data.qy_data, i_out]
551            logging.info("Computation is in progress...")
552            self.cmdCompute.setText('Wait...')
553            self.cmdCompute.setEnabled(False)
554            d = threads.deferToThread(self.complete, inputs, self._update)
555            # Add deferred callback for call return
556            d.addCallback(self.plot_1_2d)
557        except:
558            log_msg = "{}. stop".format(sys.exc_value)
559            logging.info(log_msg)
560        return
561
562    def _update(self, value):
563        """
564        Copied from previous version
565        """
566        pass
567
568    def complete(self, input, update=None):
569        """
570        Gen compute complete function
571        :Param input: input list [qx_data, qy_data, i_out]
572        """
573        out = numpy.empty(0)
574        for ind in range(len(input[0])):
575            if self.is_avg:
576                if ind % 1 == 0 and update is not None:
577                    # update()
578                    percentage = int(100.0 * float(ind) / len(input[0]))
579                    update(percentage)
580                    time.sleep(0.001)  # 0.1
581                inputi = [input[0][ind:ind + 1], [], input[2][ind:ind + 1]]
582                outi = self.model.run(inputi)
583                out = numpy.append(out, outi)
584            else:
585                if ind % 50 == 0 and update is not None:
586                    percentage = int(100.0 * float(ind) / len(input[0]))
587                    update(percentage)
588                    time.sleep(0.001)
589                inputi = [input[0][ind:ind + 1], input[1][ind:ind + 1],
590                          input[2][ind:ind + 1]]
591                outi = self.model.runXY(inputi)
592                out = numpy.append(out, outi)
593        self.data_to_plot = out
594        logging.info('Gen computation completed.')
595        self.cmdCompute.setText('Compute')
596        self.cmdCompute.setEnabled(True)
597        return
598
599    def onSaveFile(self):
600        """Save data as .sld file"""
601        path = os.path.dirname(str(self.datafile))
602        default_name = os.path.join(path, 'sld_file')
603        kwargs = {
604            'parent': self,
605            'directory': default_name,
606            'filter': 'SLD file (*.sld)',
607            'options': QtGui.QFileDialog.DontUseNativeDialog}
608        # Query user for filename.
609        filename = str(QtGui.QFileDialog.getSaveFileName(**kwargs))
610        if filename:
611            try:
612                if os.path.splitext(filename)[1].lower() == '.sld':
613                    sas_gen.SLDReader().write(filename, self.sld_data)
614                else:
615                    sas_gen.SLDReader().write('.'.join((filename, 'sld')),
616                                              self.sld_data)
617            except:
618                raise
619
620    def plot3d(self, has_arrow=False):
621        """ Generate 3D plot in real space with or without arrows """
622        self.write_new_values_from_gui()
623        graph_title = " Graph {}: {} 3D SLD Profile".format(self.graph_num,
624                                                            self.file_name)
625        if has_arrow:
626            graph_title += ' - Magnetic Vector as Arrow'
627
628        plot3D = Plotter3D(self, graph_title)
629        plot3D.plot(self.sld_data, has_arrow=has_arrow)
630        plot3D.show()
631        self.graph_num += 1
632
633    def plot_1_2d(self, d):
634        """ Generate 1D or 2D plot, called in Compute"""
635        if self.is_avg or self.is_avg is None:
636            data = Data1D(x=self.data.x, y=self.data_to_plot)
637            data.title = "GenSAS {}  #{} 1D".format(self.file_name,
638                                                    int(self.graph_num))
639            data.xaxis('\\rm{Q_{x}}', '\AA^{-1}')
640            data.yaxis('\\rm{Intensity}', 'cm^{-1}')
641            plot1D = Plotter(self)
642            plot1D.plot(data)
643            plot1D.show()
644            self.graph_num += 1
645            # TODO
646            print 'TRANSFER OF DATA TO MAIN PANEL TO BE IMPLEMENTED'
647            return plot1D
648        else:
649            numpy.nan_to_num(self.data_to_plot)
650            data = Data2D(image=self.data_to_plot,
651                          qx_data=self.data.qx_data,
652                          qy_data=self.data.qy_data,
653                          q_data=self.data.q_data,
654                          xmin=self.data.xmin, xmax=self.data.ymax,
655                          ymin=self.data.ymin, ymax=self.data.ymax,
656                          err_image=self.data.err_data)
657            data.title = "GenSAS {}  #{} 2D".format(self.file_name,
658                                                    int(self.graph_num))
659            plot2D = Plotter2D(self)
660            plot2D.plot(data)
661            plot2D.show()
662            self.graph_num += 1
663            # TODO
664            print 'TRANSFER OF DATA TO MAIN PANEL TO BE IMPLEMENTED'
665            return plot2D
666
667
668class Plotter3DWidget(PlotterBase):
669    """
670    3D Plot widget for use with a QDialog
671    """
672    def __init__(self, parent=None, manager=None):
673        super(Plotter3DWidget, self).__init__(parent,  manager=manager)
674
675    @property
676    def data(self):
677        return self._data
678
679    @data.setter
680    def data(self, data=None):
681        """ data setter """
682        self._data = data
683
684    def plot(self, data=None, has_arrow=False):
685        """
686        Plot 3D self._data
687        """
688        if not data:
689            return
690        self.data = data
691        #assert(self._data)
692        # Prepare and show the plot
693        self.showPlot(data=self.data, has_arrow=has_arrow)
694
695    def showPlot(self, data, has_arrow=False):
696        """
697        Render and show the current data
698        """
699        # If we don't have any data, skip.
700        if data is None:
701            return
702        # This import takes forever - place it here so the main UI starts faster
703        from mpl_toolkits.mplot3d import Axes3D
704        color_dic = {'H': 'blue', 'D': 'purple', 'N': 'orange',
705                     'O': 'red', 'C': 'green', 'P': 'cyan', 'Other': 'k'}
706        marker = ','
707        m_size = 2
708
709        pos_x = data.pos_x
710        pos_y = data.pos_y
711        pos_z = data.pos_z
712        sld_mx = data.sld_mx
713        sld_my = data.sld_my
714        sld_mz = data.sld_mz
715        pix_symbol = data.pix_symbol
716        sld_tot = numpy.fabs(sld_mx) + numpy.fabs(sld_my) + \
717                  numpy.fabs(sld_mz) + numpy.fabs(data.sld_n)
718        is_nonzero = sld_tot > 0.0
719        is_zero = sld_tot == 0.0
720
721        if data.pix_type == 'atom':
722            marker = 'o'
723            m_size = 3.5
724
725        self.figure.clear()
726        self.figure.subplots_adjust(left=0.1, right=.8, bottom=.1)
727        ax = Axes3D(self.figure)
728        ax.set_xlabel('x ($\A{}$)'.format(data.pos_unit))
729        ax.set_ylabel('z ($\A{}$)'.format(data.pos_unit))
730        ax.set_zlabel('y ($\A{}$)'.format(data.pos_unit))
731
732        # I. Plot null points
733        if is_zero.any():
734            im = ax.plot(pos_x[is_zero], pos_z[is_zero], pos_y[is_zero],
735                           marker, c="y", alpha=0.5, markeredgecolor='y',
736                           markersize=m_size)
737            pos_x = pos_x[is_nonzero]
738            pos_y = pos_y[is_nonzero]
739            pos_z = pos_z[is_nonzero]
740            sld_mx = sld_mx[is_nonzero]
741            sld_my = sld_my[is_nonzero]
742            sld_mz = sld_mz[is_nonzero]
743            pix_symbol = data.pix_symbol[is_nonzero]
744        # II. Plot selective points in color
745        other_color = numpy.ones(len(pix_symbol), dtype='bool')
746        for key in color_dic.keys():
747            chosen_color = pix_symbol == key
748            if numpy.any(chosen_color):
749                other_color = other_color & (chosen_color!=True)
750                color = color_dic[key]
751                im = ax.plot(pos_x[chosen_color], pos_z[chosen_color],
752                         pos_y[chosen_color], marker, c=color, alpha=0.5,
753                         markeredgecolor=color, markersize=m_size, label=key)
754        # III. Plot All others
755        if numpy.any(other_color):
756            a_name = ''
757            if data.pix_type == 'atom':
758                # Get atom names not in the list
759                a_names = [symb for symb in pix_symbol \
760                           if symb not in color_dic.keys()]
761                a_name = a_names[0]
762                for name in a_names:
763                    new_name = ", " + name
764                    if a_name.count(name) == 0:
765                        a_name += new_name
766            # plot in black
767            im = ax.plot(pos_x[other_color], pos_z[other_color],
768                         pos_y[other_color], marker, c="k", alpha=0.5,
769                         markeredgecolor="k", markersize=m_size, label=a_name)
770        if data.pix_type == 'atom':
771            ax.legend(loc='upper left', prop={'size': 10})
772        # IV. Draws atomic bond with grey lines if any
773        if data.has_conect:
774            for ind in range(len(data.line_x)):
775                im = ax.plot(data.line_x[ind], data.line_z[ind],
776                             data.line_y[ind], '-', lw=0.6, c="grey",
777                             alpha=0.3)
778        # V. Draws magnetic vectors
779        if has_arrow and len(pos_x) > 0:
780            def _draw_arrow(input=None, update=None):
781                # import moved here for performance reasons
782                from sas.qtgui.Plotting.Arrow3D import Arrow3D
783                """
784                draw magnetic vectors w/arrow
785                """
786                max_mx = max(numpy.fabs(sld_mx))
787                max_my = max(numpy.fabs(sld_my))
788                max_mz = max(numpy.fabs(sld_mz))
789                max_m = max(max_mx, max_my, max_mz)
790                try:
791                    max_step = max(data.xstepsize, data.ystepsize, data.zstepsize)
792                except:
793                    max_step = 0
794                if max_step <= 0:
795                    max_step = 5
796                try:
797                    if max_m != 0:
798                        unit_x2 = sld_mx / max_m
799                        unit_y2 = sld_my / max_m
800                        unit_z2 = sld_mz / max_m
801                        # 0.8 is for avoiding the color becomes white=(1,1,1))
802                        color_x = numpy.fabs(unit_x2 * 0.8)
803                        color_y = numpy.fabs(unit_y2 * 0.8)
804                        color_z = numpy.fabs(unit_z2 * 0.8)
805                        x2 = pos_x + unit_x2 * max_step
806                        y2 = pos_y + unit_y2 * max_step
807                        z2 = pos_z + unit_z2 * max_step
808                        x_arrow = numpy.column_stack((pos_x, x2))
809                        y_arrow = numpy.column_stack((pos_y, y2))
810                        z_arrow = numpy.column_stack((pos_z, z2))
811                        colors = numpy.column_stack((color_x, color_y, color_z))
812                        arrows = Arrow3D(self.figure, x_arrow, z_arrow, y_arrow,
813                                        colors, mutation_scale=10, lw=1,
814                                        arrowstyle="->", alpha=0.5)
815                        ax.add_artist(arrows)
816                except:
817                    pass
818                log_msg = "Arrow Drawing completed.\n"
819                logging.info(log_msg)
820            log_msg = "Arrow Drawing is in progress..."
821            logging.info(log_msg)
822
823            # Defer the drawing of arrows to another thread
824            d = threads.deferToThread(_draw_arrow, ax)
825
826        self.figure.canvas.resizing = False
827        self.figure.canvas.draw()
828
829
830class Plotter3D(QtGui.QDialog, Plotter3DWidget):
831    def __init__(self, parent=None, graph_title=''):
832        self.graph_title = graph_title
833        QtGui.QDialog.__init__(self)
834        Plotter3DWidget.__init__(self, manager=parent)
835        self.setWindowTitle(self.graph_title)
836
Note: See TracBrowser for help on using the repository browser.