source: sasview/src/sas/qtgui/Calculators/GenericScatteringCalculator.py @ 7fb471d

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalc
Last change on this file since 7fb471d was 7fb471d, checked in by Piotr Rozyczko <rozyczko@…>, 6 years ago

Update for unit tests and minor functionality quirks

  • Property mode set to 100644
File size: 33.4 KB
Line 
1import sys
2import os
3import numpy
4import logging
5import time
6
7from PyQt4 import QtGui
8from PyQt4 import QtCore
9from twisted.internet import threads
10
11import sas.qtgui.Utilities.GuiUtils as GuiUtils
12from sas.qtgui.Utilities.GenericReader import GenReader
13from sas.sascalc.dataloader.data_info import Detector
14from sas.sascalc.dataloader.data_info import Source
15from sas.sascalc.calculator import sas_gen
16from sas.qtgui.Plotting.PlotterBase import PlotterBase
17from sas.qtgui.Plotting.Plotter2D import Plotter2D
18from sas.qtgui.Plotting.Plotter import Plotter
19
20from sas.qtgui.Plotting.PlotterData import Data1D
21from sas.qtgui.Plotting.PlotterData import Data2D
22
23# Local UI
24from .UI.GenericScatteringCalculator import Ui_GenericScatteringCalculator
25
26_Q1D_MIN = 0.001
27
28
29class GenericScatteringCalculator(QtGui.QDialog, Ui_GenericScatteringCalculator):
30
31    trigger_plot_3d = QtCore.pyqtSignal()
32
33    def __init__(self, parent=None):
34        super(GenericScatteringCalculator, self).__init__()
35        self.setupUi(self)
36        self.manager = parent
37        self.communicator = self.manager.communicator()
38        self.model = sas_gen.GenSAS()
39        self.omf_reader = sas_gen.OMFReader()
40        self.sld_reader = sas_gen.SLDReader()
41        self.pdb_reader = sas_gen.PDBReader()
42        self.reader = None
43        self.sld_data = None
44
45        self.parameters = []
46        self.data = None
47        self.datafile = None
48        self.file_name = ''
49        self.ext = None
50        self.default_shape = str(self.cbShape.currentText())
51        self.is_avg = False
52        self.data_to_plot = None
53        self.graph_num = 1  # index for name of graph
54
55        # combox box
56        self.cbOptionsCalc.setVisible(False)
57
58        # push buttons
59        self.cmdClose.clicked.connect(self.accept)
60        self.cmdHelp.clicked.connect(self.onHelp)
61
62        self.cmdLoad.clicked.connect(self.loadFile)
63        self.cmdCompute.clicked.connect(self.onCompute)
64        self.cmdReset.clicked.connect(self.onReset)
65        self.cmdSave.clicked.connect(self.onSaveFile)
66
67        self.cmdDraw.clicked.connect(lambda: self.plot3d(has_arrow=True))
68        self.cmdDrawpoints.clicked.connect(lambda: self.plot3d(has_arrow=False))
69
70        # validators
71        # scale, volume and background must be positive
72        validat_regex_pos = QtCore.QRegExp('^[+]?([.]\d+|\d+([.]\d+)?)$')
73        self.txtScale.setValidator(QtGui.QRegExpValidator(validat_regex_pos,
74                                                          self.txtScale))
75        self.txtBackground.setValidator(QtGui.QRegExpValidator(
76            validat_regex_pos, self.txtBackground))
77        self.txtTotalVolume.setValidator(QtGui.QRegExpValidator(
78            validat_regex_pos, self.txtTotalVolume))
79
80        # fraction of spin up between 0 and 1
81        validat_regexbetween0_1 = QtCore.QRegExp('^(0(\.\d*)*|1(\.0+)?)$')
82        self.txtUpFracIn.setValidator(
83            QtGui.QRegExpValidator(validat_regexbetween0_1, self.txtUpFracIn))
84        self.txtUpFracOut.setValidator(
85            QtGui.QRegExpValidator(validat_regexbetween0_1, self.txtUpFracOut))
86
87        # 0 < Qmax <= 1000
88        validat_regex_q = QtCore.QRegExp('^1000$|^[+]?(\d{1,3}([.]\d+)?)$')
89        self.txtQxMax.setValidator(QtGui.QRegExpValidator(validat_regex_q,
90                                                          self.txtQxMax))
91        self.txtQxMax.textChanged.connect(self.check_value)
92
93        # 2 <= Qbin <= 1000
94        self.txtNoQBins.setValidator(QtGui.QRegExpValidator(validat_regex_q,
95                                                            self.txtNoQBins))
96        self.txtNoQBins.textChanged.connect(self.check_value)
97
98        # plots - 3D in real space
99        self.trigger_plot_3d.connect(lambda: self.plot3d(has_arrow=False))
100
101        # TODO the option Ellipsoid has not been implemented
102        self.cbShape.currentIndexChanged.connect(self.selectedshapechange)
103
104        # New font to display angstrom symbol
105        new_font = 'font-family: -apple-system, "Helvetica Neue", "Ubuntu";'
106        self.lblUnitSolventSLD.setStyleSheet(new_font)
107        self.lblUnitVolume.setStyleSheet(new_font)
108        self.lbl5.setStyleSheet(new_font)
109        self.lblUnitMx.setStyleSheet(new_font)
110        self.lblUnitMy.setStyleSheet(new_font)
111        self.lblUnitMz.setStyleSheet(new_font)
112        self.lblUnitNucl.setStyleSheet(new_font)
113        self.lblUnitx.setStyleSheet(new_font)
114        self.lblUnity.setStyleSheet(new_font)
115        self.lblUnitz.setStyleSheet(new_font)
116
117    def selectedshapechange(self):
118        """
119        TODO Temporary solution to display information about option 'Ellipsoid'
120        """
121        print("The option Ellipsoid has not been implemented yet.")
122        self.communicator.statusBarUpdateSignal.emit(
123            "The option Ellipsoid has not been implemented yet.")
124
125    def loadFile(self):
126        """
127        Open menu to choose the datafile to load
128        Only extensions .SLD, .PDB, .OMF, .sld, .pdb, .omf
129        """
130        try:
131            self.datafile = QtGui.QFileDialog.getOpenFileName(
132                self, "Choose a file", "", "All Gen files (*.OMF *.omf) ;;"
133                                          "SLD files (*.SLD *.sld);;PDB files (*.pdb *.PDB);; "
134                                          "OMF files (*.OMF *.omf);; "
135                                          "All files (*.*)")
136            if self.datafile:
137                self.default_shape = str(self.cbShape.currentText())
138                self.file_name = os.path.basename(str(self.datafile))
139                self.ext = os.path.splitext(str(self.datafile))[1]
140                if self.ext in self.omf_reader.ext:
141                    loader = self.omf_reader
142                elif self.ext in self.sld_reader.ext:
143                    loader = self.sld_reader
144                elif self.ext in self.pdb_reader.ext:
145                    loader = self.pdb_reader
146                else:
147                    loader = None
148                # disable some entries depending on type of loaded file
149                # (according to documentation)
150                if self.ext.lower() in ['.sld', '.omf', '.pdb']:
151                    self.txtUpFracIn.setEnabled(False)
152                    self.txtUpFracOut.setEnabled(False)
153                    self.txtUpTheta.setEnabled(False)
154
155                if self.reader is not None and self.reader.isrunning():
156                    self.reader.stop()
157                self.cmdLoad.setEnabled(False)
158                self.cmdLoad.setText('Loading...')
159                self.communicator.statusBarUpdateSignal.emit(
160                    "Loading File {}".format(os.path.basename(
161                        str(self.datafile))))
162                self.reader = GenReader(path=str(self.datafile), loader=loader,
163                                        completefn=self.complete_loading,
164                                        updatefn=self.load_update)
165                self.reader.queue()
166        except (RuntimeError, IOError):
167            log_msg = "Generic SAS Calculator: %s" % sys.exc_info()[1]
168            logging.info(log_msg)
169            raise
170        return
171
172    def load_update(self):
173        """ Legacy function used in GenRead """
174        if self.reader.isrunning():
175            status_type = "progress"
176        else:
177            status_type = "stop"
178        logging.info(status_type)
179
180    def complete_loading(self, data=None):
181        """ Function used in GenRead"""
182        self.cbShape.setEnabled(False)
183        try:
184            is_pdbdata = False
185            self.txtData.setText(os.path.basename(str(self.datafile)))
186            self.is_avg = False
187            if self.ext in self.omf_reader.ext:
188                gen = sas_gen.OMF2SLD()
189                gen.set_data(data)
190                self.sld_data = gen.get_magsld()
191                self.check_units()
192            elif self.ext in self.sld_reader.ext:
193                self.sld_data = data
194            elif self.ext in self.pdb_reader.ext:
195                self.sld_data = data
196                is_pdbdata = True
197            # Display combobox of orientation only for pdb data
198            self.cbOptionsCalc.setVisible(is_pdbdata)
199            self.update_gui()
200        except IOError:
201            log_msg = "Loading Error: " \
202                      "This file format is not supported for GenSAS."
203            logging.info(log_msg)
204            raise
205        except ValueError:
206            log_msg = "Could not find any data"
207            logging.info(log_msg)
208            raise
209        logging.info("Load Complete")
210        self.cmdLoad.setEnabled(True)
211        self.cmdLoad.setText('Load')
212        self.trigger_plot_3d.emit()
213
214    def check_units(self):
215        """
216        Check if the units from the OMF file correspond to the default ones
217        displayed on the interface.
218        If not, modify the GUI with the correct unit
219        """
220        #  TODO: adopt the convention of font and symbol for the updated values
221        if sas_gen.OMFData().valueunit != 'A^(-2)':
222            value_unit = sas_gen.OMFData().valueunit
223            self.lbl_unitMx.setText(value_unit)
224            self.lbl_unitMy.setText(value_unit)
225            self.lbl_unitMz.setText(value_unit)
226            self.lbl_unitNucl.setText(value_unit)
227        if sas_gen.OMFData().meshunit != 'A':
228            mesh_unit = sas_gen.OMFData().meshunit
229            self.lbl_unitx.setText(mesh_unit)
230            self.lbl_unity.setText(mesh_unit)
231            self.lbl_unitz.setText(mesh_unit)
232            self.lbl_unitVolume.setText(mesh_unit+"^3")
233
234    def check_value(self):
235        """Check range of text edits for QMax and Number of Qbins """
236        text_edit = self.sender()
237        text_edit.setStyleSheet('background-color: rgb(255, 255, 255);')
238        if text_edit.text():
239            value = float(str(text_edit.text()))
240            if text_edit == self.txtQxMax:
241                if value <= 0 or value > 1000:
242                    text_edit.setStyleSheet('background-color: rgb(255, 182, 193);')
243                else:
244                    text_edit.setStyleSheet('background-color: rgb(255, 255, 255);')
245            elif text_edit == self.txtNoQBins:
246                if value < 2 or value > 1000:
247                    self.txtNoQBins.setStyleSheet('background-color: rgb(255, 182, 193);')
248                else:
249                    self.txtNoQBins.setStyleSheet('background-color: rgb(255, 255, 255);')
250
251    def update_gui(self):
252        """ Update the interface with values from loaded data """
253        self.model.set_is_avg(self.is_avg)
254        self.model.set_sld_data(self.sld_data)
255        self.model.params['total_volume'] = len(self.sld_data.sld_n)*self.sld_data.vol_pix[0]
256
257        # add condition for activation of save button
258        self.cmdSave.setEnabled(True)
259
260        # activation of 3D plots' buttons (with and without arrows)
261        self.cmdDraw.setEnabled(self.sld_data is not None)
262        self.cmdDrawpoints.setEnabled(self.sld_data is not None)
263
264        self.txtScale.setText(str(self.model.params['scale']))
265        self.txtBackground.setText(str(self.model.params['background']))
266        self.txtSolventSLD.setText(str(self.model.params['solvent_SLD']))
267
268        # Volume to write to interface: npts x volume of first pixel
269        self.txtTotalVolume.setText(str(len(self.sld_data.sld_n)*self.sld_data.vol_pix[0]))
270
271        self.txtUpFracIn.setText(str(self.model.params['Up_frac_in']))
272        self.txtUpFracOut.setText(str(self.model.params['Up_frac_out']))
273        self.txtUpTheta.setText(str(self.model.params['Up_theta']))
274
275        self.txtNoPixels.setText(str(len(self.sld_data.sld_n)))
276        self.txtNoPixels.setEnabled(False)
277
278        list_parameters = ['sld_mx', 'sld_my', 'sld_mz', 'sld_n', 'xnodes',
279                           'ynodes', 'znodes', 'xstepsize', 'ystepsize',
280                           'zstepsize']
281        list_gui_button = [self.txtMx, self.txtMy, self.txtMz, self.txtNucl,
282                           self.txtXnodes, self.txtYnodes, self.txtZnodes,
283                           self.txtXstepsize, self.txtYstepsize,
284                           self.txtZstepsize]
285
286        # Fill right hand side of GUI
287        for indx, item in enumerate(list_parameters):
288            if getattr(self.sld_data, item) is None:
289                list_gui_button[indx].setText('NaN')
290            else:
291                value = getattr(self.sld_data, item)
292                if isinstance(value, numpy.ndarray):
293                    item_for_gui = str(GuiUtils.formatNumber(numpy.average(value), True))
294                else:
295                    item_for_gui = str(GuiUtils.formatNumber(value, True))
296                list_gui_button[indx].setText(item_for_gui)
297
298        # Enable / disable editing of right hand side of GUI
299        for indx, item in enumerate(list_parameters):
300            if indx < 4:
301                # this condition only applies to Mx,y,z and Nucl
302                value = getattr(self.sld_data, item)
303                enable = self.sld_data.pix_type == 'pixel' \
304                         and numpy.min(value) == numpy.max(value)
305            else:
306                enable = not self.sld_data.is_data
307            list_gui_button[indx].setEnabled(enable)
308
309    def write_new_values_from_gui(self):
310        """
311        update parameters using modified inputs from GUI
312        used before computing
313        """
314        if self.txtScale.isModified():
315            self.model.params['scale'] = float(self.txtScale.text())
316
317        if self.txtBackground.isModified():
318            self.model.params['background'] = float(self.txtBackground.text())
319
320        if self.txtSolventSLD.isModified():
321            self.model.params['solvent_SLD'] = float(self.txtSolventSLD.text())
322
323        # Different condition for total volume to get correct volume after
324        # applying set_sld_data in compute
325        if self.txtTotalVolume.isModified() \
326                or self.model.params['total_volume'] != float(self.txtTotalVolume.text()):
327            self.model.params['total_volume'] = float(self.txtTotalVolume.text())
328
329        if self.txtUpFracIn.isModified():
330            self.model.params['Up_frac_in'] = float(self.txtUpFracIn.text())
331
332        if self.txtUpFracOut.isModified():
333            self.model.params['Up_frac_out'] = float(self.txtUpFracOut.text())
334
335        if self.txtUpTheta.isModified():
336            self.model.params['Up_theta'] = float(self.txtUpTheta.text())
337
338        if self.txtMx.isModified():
339            self.sld_data.sld_mx = float(self.txtMx.text())*\
340                                   numpy.ones(len(self.sld_data.sld_mx))
341
342        if self.txtMy.isModified():
343            self.sld_data.sld_my = float(self.txtMy.text())*\
344                                   numpy.ones(len(self.sld_data.sld_my))
345
346        if self.txtMz.isModified():
347            self.sld_data.sld_mz = float(self.txtMz.text())*\
348                                   numpy.ones(len(self.sld_data.sld_mz))
349
350        if self.txtNucl.isModified():
351            self.sld_data.sld_n = float(self.txtNucl.text())*\
352                                  numpy.ones(len(self.sld_data.sld_n))
353
354        if self.txtXnodes.isModified():
355            self.sld_data.xnodes = int(self.txtXnodes.text())
356
357        if self.txtYnodes.isModified():
358            self.sld_data.ynodes = int(self.txtYnodes.text())
359
360        if self.txtZnodes.isModified():
361            self.sld_data.znodes = int(self.txtZnodes.text())
362
363        if self.txtXstepsize.isModified():
364            self.sld_data.xstepsize = float(self.txtXstepsize.text())
365
366        if self.txtYstepsize.isModified():
367            self.sld_data.ystepsize = float(self.txtYstepsize.text())
368
369        if self.txtZstepsize.isModified():
370            self.sld_data.zstepsize = float(self.txtZstepsize.text())
371
372        if self.cbOptionsCalc.isVisible():
373            self.is_avg = (self.cbOptionsCalc.currentIndex() == 1)
374
375    def onHelp(self):
376        """
377        Bring up the Generic Scattering calculator Documentation whenever
378        the HELP button is clicked.
379        Calls Documentation Window with the path of the location within the
380        documentation tree (after /doc/ ....".
381        """
382        try:
383            location = GuiUtils.HELP_DIRECTORY_LOCATION + \
384                       "/user/sasgui/perspectives/calculator/sas_calculator_help.html"
385            self.manager._helpView.load(QtCore.QUrl(location))
386            self.manager._helpView.show()
387        except AttributeError:
388            # No manager defined - testing and standalone runs
389            pass
390
391    def onReset(self):
392        """ Reset the inputs of textEdit to default values """
393        try:
394            # reset values in textedits
395            self.txtUpFracIn.setText("1.0")
396            self.txtUpFracOut.setText("1.0")
397            self.txtUpTheta.setText("0.0")
398            self.txtBackground.setText("0.0")
399            self.txtScale.setText("1.0")
400            self.txtSolventSLD.setText("0.0")
401            self.txtTotalVolume.setText("216000.0")
402            self.txtNoQBins.setText("50")
403            self.txtQxMax.setText("0.3")
404            self.txtNoPixels.setText("1000")
405            self.txtMx.setText("0")
406            self.txtMy.setText("0")
407            self.txtMz.setText("0")
408            self.txtNucl.setText("6.97e-06")
409            self.txtXnodes.setText("10")
410            self.txtYnodes.setText("10")
411            self.txtZnodes.setText("10")
412            self.txtXstepsize.setText("6")
413            self.txtYstepsize.setText("6")
414            self.txtZstepsize.setText("6")
415            # reset Load button and textedit
416            self.txtData.setText('Default SLD Profile')
417            self.cmdLoad.setEnabled(True)
418            self.cmdLoad.setText('Load')
419            # reset option for calculation
420            self.cbOptionsCalc.setCurrentIndex(0)
421            self.cbOptionsCalc.setVisible(False)
422            # reset shape button
423            self.cbShape.setCurrentIndex(0)
424            self.cbShape.setEnabled(True)
425            # reset compute button
426            self.cmdCompute.setText('Compute')
427            self.cmdCompute.setEnabled(True)
428            # TODO reload default data set
429            self._create_default_sld_data()
430
431        finally:
432            pass
433
434    def _create_default_2d_data(self):
435        """
436        Copied from previous version
437        Create 2D data by default
438        :warning: This data is never plotted.
439        """
440        self.qmax_x = float(self.txtQxMax.text())
441        self.npts_x = int(self.txtNoQBins.text())
442        self.data = Data2D()
443        self.data.is_data = False
444        # # Default values
445        self.data.detector.append(Detector())
446        index = len(self.data.detector) - 1
447        self.data.detector[index].distance = 8000  # mm
448        self.data.source.wavelength = 6  # A
449        self.data.detector[index].pixel_size.x = 5  # mm
450        self.data.detector[index].pixel_size.y = 5  # mm
451        self.data.detector[index].beam_center.x = self.qmax_x
452        self.data.detector[index].beam_center.y = self.qmax_x
453        xmax = self.qmax_x
454        xmin = -xmax
455        ymax = self.qmax_x
456        ymin = -ymax
457        qstep = self.npts_x
458
459        x = numpy.linspace(start=xmin, stop=xmax, num=qstep, endpoint=True)
460        y = numpy.linspace(start=ymin, stop=ymax, num=qstep, endpoint=True)
461        # use data info instead
462        new_x = numpy.tile(x, (len(y), 1))
463        new_y = numpy.tile(y, (len(x), 1))
464        new_y = new_y.swapaxes(0, 1)
465        # all data require now in 1d array
466        qx_data = new_x.flatten()
467        qy_data = new_y.flatten()
468        q_data = numpy.sqrt(qx_data * qx_data + qy_data * qy_data)
469        # set all True (standing for unmasked) as default
470        mask = numpy.ones(len(qx_data), dtype=bool)
471        self.data.source = Source()
472        self.data.data = numpy.ones(len(mask))
473        self.data.err_data = numpy.ones(len(mask))
474        self.data.qx_data = qx_data
475        self.data.qy_data = qy_data
476        self.data.q_data = q_data
477        self.data.mask = mask
478        # store x and y bin centers in q space
479        self.data.x_bins = x
480        self.data.y_bins = y
481        # max and min taking account of the bin sizes
482        self.data.xmin = xmin
483        self.data.xmax = xmax
484        self.data.ymin = ymin
485        self.data.ymax = ymax
486
487    def _create_default_sld_data(self):
488        """
489        Copied from previous version
490        Making default sld-data
491        """
492        sld_n_default = 6.97e-06  # what is this number??
493        omfdata = sas_gen.OMFData()
494        omf2sld = sas_gen.OMF2SLD()
495        omf2sld.set_data(omfdata, self.default_shape)
496        self.sld_data = omf2sld.output
497        self.sld_data.is_data = False
498        self.sld_data.filename = "Default SLD Profile"
499        self.sld_data.set_sldn(sld_n_default)
500
501    def _create_default_1d_data(self):
502        """
503        Copied from previous version
504        Create 1D data by default
505        :warning: This data is never plotted.
506                    residuals.x = data_copy.x[index]
507            residuals.dy = numpy.ones(len(residuals.y))
508            residuals.dx = None
509            residuals.dxl = None
510            residuals.dxw = None
511        """
512        self.qmax_x = float(self.txtQxMax.text())
513        self.npts_x = int(self.txtNoQBins.text())
514        #  Default values
515        xmax = self.qmax_x
516        xmin = self.qmax_x * _Q1D_MIN
517        qstep = self.npts_x
518        x = numpy.linspace(start=xmin, stop=xmax, num=qstep, endpoint=True)
519        # store x and y bin centers in q space
520        y = numpy.ones(len(x))
521        dy = numpy.zeros(len(x))
522        dx = numpy.zeros(len(x))
523        self.data = Data1D(x=x, y=y)
524        self.data.dx = dx
525        self.data.dy = dy
526
527    def onCompute(self):
528        """
529        Copied from previous version
530        Execute the computation of I(qx, qy)
531        """
532        # Set default data when nothing loaded yet
533        if self.sld_data is None:
534            self._create_default_sld_data()
535        try:
536            self.model.set_sld_data(self.sld_data)
537            self.write_new_values_from_gui()
538            if self.is_avg or self.is_avg is None:
539                self._create_default_1d_data()
540                i_out = numpy.zeros(len(self.data.y))
541                inputs = [self.data.x, [], i_out]
542            else:
543                self._create_default_2d_data()
544                i_out = numpy.zeros(len(self.data.data))
545                inputs = [self.data.qx_data, self.data.qy_data, i_out]
546            logging.info("Computation is in progress...")
547            self.cmdCompute.setText('Wait...')
548            self.cmdCompute.setEnabled(False)
549            d = threads.deferToThread(self.complete, inputs, self._update)
550            # Add deferred callback for call return
551            d.addCallback(self.plot_1_2d)
552            d.addErrback(self.calculateFailed)
553        except:
554            log_msg = "{}. stop".format(sys.exc_info()[1])
555            logging.info(log_msg)
556        return
557
558    def _update(self, value):
559        """
560        Copied from previous version
561        """
562        pass
563
564    def calculateFailed(self, reason):
565        """
566        """
567        print("Calculate Failed with:\n", reason)
568        pass
569
570    def complete(self, input, update=None):
571        """
572        Gen compute complete function
573        :Param input: input list [qx_data, qy_data, i_out]
574        """
575        out = numpy.empty(0)
576        for ind in range(len(input[0])):
577            if self.is_avg:
578                if ind % 1 == 0 and update is not None:
579                    # update()
580                    percentage = int(100.0 * float(ind) / len(input[0]))
581                    update(percentage)
582                    time.sleep(0.001)  # 0.1
583                inputi = [input[0][ind:ind + 1], [], input[2][ind:ind + 1]]
584                outi = self.model.run(inputi)
585                out = numpy.append(out, outi)
586            else:
587                if ind % 50 == 0 and update is not None:
588                    percentage = int(100.0 * float(ind) / len(input[0]))
589                    update(percentage)
590                    time.sleep(0.001)
591                inputi = [input[0][ind:ind + 1], input[1][ind:ind + 1],
592                          input[2][ind:ind + 1]]
593                outi = self.model.runXY(inputi)
594                out = numpy.append(out, outi)
595        self.data_to_plot = out
596        logging.info('Gen computation completed.')
597        self.cmdCompute.setText('Compute')
598        self.cmdCompute.setEnabled(True)
599        return
600
601    def onSaveFile(self):
602        """Save data as .sld file"""
603        path = os.path.dirname(str(self.datafile))
604        default_name = os.path.join(path, 'sld_file')
605        kwargs = {
606            'parent': self,
607            'directory': default_name,
608            'filter': 'SLD file (*.sld)',
609            'options': QtGui.QFileDialog.DontUseNativeDialog}
610        # Query user for filename.
611        filename = str(QtGui.QFileDialog.getSaveFileName(**kwargs))
612        if filename:
613            try:
614                if os.path.splitext(filename)[1].lower() == '.sld':
615                    sas_gen.SLDReader().write(filename, self.sld_data)
616                else:
617                    sas_gen.SLDReader().write('.'.join((filename, 'sld')),
618                                              self.sld_data)
619            except:
620                raise
621
622    def plot3d(self, has_arrow=False):
623        """ Generate 3D plot in real space with or without arrows """
624        self.write_new_values_from_gui()
625        graph_title = " Graph {}: {} 3D SLD Profile".format(self.graph_num,
626                                                            self.file_name)
627        if has_arrow:
628            graph_title += ' - Magnetic Vector as Arrow'
629
630        plot3D = Plotter3D(self, graph_title)
631        plot3D.plot(self.sld_data, has_arrow=has_arrow)
632        plot3D.show()
633        self.graph_num += 1
634
635    def plot_1_2d(self, d):
636        """ Generate 1D or 2D plot, called in Compute"""
637        if self.is_avg or self.is_avg is None:
638            data = Data1D(x=self.data.x, y=self.data_to_plot)
639            data.title = "GenSAS {}  #{} 1D".format(self.file_name,
640                                                    int(self.graph_num))
641            data.xaxis('\\rm{Q_{x}}', '\AA^{-1}')
642            data.yaxis('\\rm{Intensity}', 'cm^{-1}')
643            plot1D = Plotter(self)
644            plot1D.plot(data)
645            plot1D.show()
646            self.graph_num += 1
647            # TODO
648            print('TRANSFER OF DATA TO MAIN PANEL TO BE IMPLEMENTED')
649            return plot1D
650        else:
651            numpy.nan_to_num(self.data_to_plot)
652            data = Data2D(image=self.data_to_plot,
653                          qx_data=self.data.qx_data,
654                          qy_data=self.data.qy_data,
655                          q_data=self.data.q_data,
656                          xmin=self.data.xmin, xmax=self.data.ymax,
657                          ymin=self.data.ymin, ymax=self.data.ymax,
658                          err_image=self.data.err_data)
659            data.title = "GenSAS {}  #{} 2D".format(self.file_name,
660                                                    int(self.graph_num))
661            plot2D = Plotter2D(self)
662            plot2D.plot(data)
663            plot2D.show()
664            self.graph_num += 1
665            # TODO
666            print('TRANSFER OF DATA TO MAIN PANEL TO BE IMPLEMENTED')
667            return plot2D
668
669
670class Plotter3DWidget(PlotterBase):
671    """
672    3D Plot widget for use with a QDialog
673    """
674    def __init__(self, parent=None, manager=None):
675        super(Plotter3DWidget, self).__init__(parent,  manager=manager)
676
677    @property
678    def data(self):
679        return self._data
680
681    @data.setter
682    def data(self, data=None):
683        """ data setter """
684        self._data = data
685
686    def plot(self, data=None, has_arrow=False):
687        """
688        Plot 3D self._data
689        """
690        if not data:
691            return
692        self.data = data
693        #assert(self._data)
694        # Prepare and show the plot
695        self.showPlot(data=self.data, has_arrow=has_arrow)
696
697    def showPlot(self, data, has_arrow=False):
698        """
699        Render and show the current data
700        """
701        # If we don't have any data, skip.
702        if data is None:
703            return
704        # This import takes forever - place it here so the main UI starts faster
705        from mpl_toolkits.mplot3d import Axes3D
706        color_dic = {'H': 'blue', 'D': 'purple', 'N': 'orange',
707                     'O': 'red', 'C': 'green', 'P': 'cyan', 'Other': 'k'}
708        marker = ','
709        m_size = 2
710
711        pos_x = data.pos_x
712        pos_y = data.pos_y
713        pos_z = data.pos_z
714        sld_mx = data.sld_mx
715        sld_my = data.sld_my
716        sld_mz = data.sld_mz
717        pix_symbol = data.pix_symbol
718        sld_tot = numpy.fabs(sld_mx) + numpy.fabs(sld_my) + \
719                  numpy.fabs(sld_mz) + numpy.fabs(data.sld_n)
720        is_nonzero = sld_tot > 0.0
721        is_zero = sld_tot == 0.0
722
723        if data.pix_type == 'atom':
724            marker = 'o'
725            m_size = 3.5
726
727        self.figure.clear()
728        self.figure.subplots_adjust(left=0.1, right=.8, bottom=.1)
729        ax = Axes3D(self.figure)
730        ax.set_xlabel('x ($\A{}$)'.format(data.pos_unit))
731        ax.set_ylabel('z ($\A{}$)'.format(data.pos_unit))
732        ax.set_zlabel('y ($\A{}$)'.format(data.pos_unit))
733
734        # I. Plot null points
735        if is_zero.any():
736            im = ax.plot(pos_x[is_zero], pos_z[is_zero], pos_y[is_zero],
737                           marker, c="y", alpha=0.5, markeredgecolor='y',
738                           markersize=m_size)
739            pos_x = pos_x[is_nonzero]
740            pos_y = pos_y[is_nonzero]
741            pos_z = pos_z[is_nonzero]
742            sld_mx = sld_mx[is_nonzero]
743            sld_my = sld_my[is_nonzero]
744            sld_mz = sld_mz[is_nonzero]
745            pix_symbol = data.pix_symbol[is_nonzero]
746        # II. Plot selective points in color
747        other_color = numpy.ones(len(pix_symbol), dtype='bool')
748        for key in list(color_dic.keys()):
749            chosen_color = pix_symbol == key
750            if numpy.any(chosen_color):
751                other_color = other_color & (chosen_color!=True)
752                color = color_dic[key]
753                im = ax.plot(pos_x[chosen_color], pos_z[chosen_color],
754                         pos_y[chosen_color], marker, c=color, alpha=0.5,
755                         markeredgecolor=color, markersize=m_size, label=key)
756        # III. Plot All others
757        if numpy.any(other_color):
758            a_name = ''
759            if data.pix_type == 'atom':
760                # Get atom names not in the list
761                a_names = [symb for symb in pix_symbol \
762                           if symb not in list(color_dic.keys())]
763                a_name = a_names[0]
764                for name in a_names:
765                    new_name = ", " + name
766                    if a_name.count(name) == 0:
767                        a_name += new_name
768            # plot in black
769            im = ax.plot(pos_x[other_color], pos_z[other_color],
770                         pos_y[other_color], marker, c="k", alpha=0.5,
771                         markeredgecolor="k", markersize=m_size, label=a_name)
772        if data.pix_type == 'atom':
773            ax.legend(loc='upper left', prop={'size': 10})
774        # IV. Draws atomic bond with grey lines if any
775        if data.has_conect:
776            for ind in range(len(data.line_x)):
777                im = ax.plot(data.line_x[ind], data.line_z[ind],
778                             data.line_y[ind], '-', lw=0.6, c="grey",
779                             alpha=0.3)
780        # V. Draws magnetic vectors
781        if has_arrow and len(pos_x) > 0:
782            def _draw_arrow(input=None, update=None):
783                # import moved here for performance reasons
784                from sas.qtgui.Plotting.Arrow3D import Arrow3D
785                """
786                draw magnetic vectors w/arrow
787                """
788                max_mx = max(numpy.fabs(sld_mx))
789                max_my = max(numpy.fabs(sld_my))
790                max_mz = max(numpy.fabs(sld_mz))
791                max_m = max(max_mx, max_my, max_mz)
792                try:
793                    max_step = max(data.xstepsize, data.ystepsize, data.zstepsize)
794                except:
795                    max_step = 0
796                if max_step <= 0:
797                    max_step = 5
798                try:
799                    if max_m != 0:
800                        unit_x2 = sld_mx / max_m
801                        unit_y2 = sld_my / max_m
802                        unit_z2 = sld_mz / max_m
803                        # 0.8 is for avoiding the color becomes white=(1,1,1))
804                        color_x = numpy.fabs(unit_x2 * 0.8)
805                        color_y = numpy.fabs(unit_y2 * 0.8)
806                        color_z = numpy.fabs(unit_z2 * 0.8)
807                        x2 = pos_x + unit_x2 * max_step
808                        y2 = pos_y + unit_y2 * max_step
809                        z2 = pos_z + unit_z2 * max_step
810                        x_arrow = numpy.column_stack((pos_x, x2))
811                        y_arrow = numpy.column_stack((pos_y, y2))
812                        z_arrow = numpy.column_stack((pos_z, z2))
813                        colors = numpy.column_stack((color_x, color_y, color_z))
814                        arrows = Arrow3D(self.figure, x_arrow, z_arrow, y_arrow,
815                                        colors, mutation_scale=10, lw=1,
816                                        arrowstyle="->", alpha=0.5)
817                        ax.add_artist(arrows)
818                except:
819                    pass
820                log_msg = "Arrow Drawing completed.\n"
821                logging.info(log_msg)
822            log_msg = "Arrow Drawing is in progress..."
823            logging.info(log_msg)
824
825            # Defer the drawing of arrows to another thread
826            d = threads.deferToThread(_draw_arrow, ax)
827
828        self.figure.canvas.resizing = False
829        self.figure.canvas.draw()
830
831
832class Plotter3D(QtGui.QDialog, Plotter3DWidget):
833    def __init__(self, parent=None, graph_title=''):
834        self.graph_title = graph_title
835        QtGui.QDialog.__init__(self)
836        Plotter3DWidget.__init__(self, manager=parent)
837        self.setWindowTitle(self.graph_title)
838
Note: See TracBrowser for help on using the repository browser.