source: sasview/src/sas/qtgui/Calculators/GenericScatteringCalculator.py @ c6fb57c

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalc
Last change on this file since c6fb57c was 53c771e, checked in by Piotr Rozyczko <rozyczko@…>, 7 years ago

Converted unit tests

  • Property mode set to 100644
File size: 33.4 KB
Line 
1import sys
2import os
3import numpy
4import logging
5import time
6
7from PyQt5 import QtCore
8from PyQt5 import QtGui
9from PyQt5 import QtWidgets
10
11from twisted.internet import threads
12
13import sas.qtgui.Utilities.GuiUtils as GuiUtils
14from sas.qtgui.Utilities.GenericReader import GenReader
15from sas.sascalc.dataloader.data_info import Detector
16from sas.sascalc.dataloader.data_info import Source
17from sas.sascalc.calculator import sas_gen
18from sas.qtgui.Plotting.PlotterBase import PlotterBase
19from sas.qtgui.Plotting.Plotter2D import Plotter2D
20from sas.qtgui.Plotting.Plotter import Plotter
21
22from sas.qtgui.Plotting.PlotterData import Data1D
23from sas.qtgui.Plotting.PlotterData import Data2D
24
25# Local UI
26from .UI.GenericScatteringCalculator import Ui_GenericScatteringCalculator
27
28_Q1D_MIN = 0.001
29
30
31class GenericScatteringCalculator(QtWidgets.QDialog, Ui_GenericScatteringCalculator):
32
33    trigger_plot_3d = QtCore.pyqtSignal()
34
35    def __init__(self, parent=None):
36        super(GenericScatteringCalculator, self).__init__()
37        self.setupUi(self)
38        self.manager = parent
39        self.communicator = self.manager.communicator()
40        self.model = sas_gen.GenSAS()
41        self.omf_reader = sas_gen.OMFReader()
42        self.sld_reader = sas_gen.SLDReader()
43        self.pdb_reader = sas_gen.PDBReader()
44        self.reader = None
45        self.sld_data = None
46
47        self.parameters = []
48        self.data = None
49        self.datafile = None
50        self.file_name = ''
51        self.ext = None
52        self.default_shape = str(self.cbShape.currentText())
53        self.is_avg = False
54        self.data_to_plot = None
55        self.graph_num = 1  # index for name of graph
56
57        # combox box
58        self.cbOptionsCalc.setVisible(False)
59
60        # push buttons
61        self.cmdClose.clicked.connect(self.accept)
62        self.cmdHelp.clicked.connect(self.onHelp)
63
64        self.cmdLoad.clicked.connect(self.loadFile)
65        self.cmdCompute.clicked.connect(self.onCompute)
66        self.cmdReset.clicked.connect(self.onReset)
67        self.cmdSave.clicked.connect(self.onSaveFile)
68
69        self.cmdDraw.clicked.connect(lambda: self.plot3d(has_arrow=True))
70        self.cmdDrawpoints.clicked.connect(lambda: self.plot3d(has_arrow=False))
71
72        # validators
73        # scale, volume and background must be positive
74        validat_regex_pos = QtCore.QRegExp('^[+]?([.]\d+|\d+([.]\d+)?)$')
75        self.txtScale.setValidator(QtGui.QRegExpValidator(validat_regex_pos,
76                                                          self.txtScale))
77        self.txtBackground.setValidator(QtGui.QRegExpValidator(
78            validat_regex_pos, self.txtBackground))
79        self.txtTotalVolume.setValidator(QtGui.QRegExpValidator(
80            validat_regex_pos, self.txtTotalVolume))
81
82        # fraction of spin up between 0 and 1
83        validat_regexbetween0_1 = QtCore.QRegExp('^(0(\.\d*)*|1(\.0+)?)$')
84        self.txtUpFracIn.setValidator(
85            QtGui.QRegExpValidator(validat_regexbetween0_1, self.txtUpFracIn))
86        self.txtUpFracOut.setValidator(
87            QtGui.QRegExpValidator(validat_regexbetween0_1, self.txtUpFracOut))
88
89        # 0 < Qmax <= 1000
90        validat_regex_q = QtCore.QRegExp('^1000$|^[+]?(\d{1,3}([.]\d+)?)$')
91        self.txtQxMax.setValidator(QtGui.QRegExpValidator(validat_regex_q,
92                                                          self.txtQxMax))
93        self.txtQxMax.textChanged.connect(self.check_value)
94
95        # 2 <= Qbin <= 1000
96        self.txtNoQBins.setValidator(QtGui.QRegExpValidator(validat_regex_q,
97                                                            self.txtNoQBins))
98        self.txtNoQBins.textChanged.connect(self.check_value)
99
100        # plots - 3D in real space
101        self.trigger_plot_3d.connect(lambda: self.plot3d(has_arrow=False))
102
103        # TODO the option Ellipsoid has not been implemented
104        self.cbShape.currentIndexChanged.connect(self.selectedshapechange)
105
106        # New font to display angstrom symbol
107        new_font = 'font-family: -apple-system, "Helvetica Neue", "Ubuntu";'
108        self.lblUnitSolventSLD.setStyleSheet(new_font)
109        self.lblUnitVolume.setStyleSheet(new_font)
110        self.lbl5.setStyleSheet(new_font)
111        self.lblUnitMx.setStyleSheet(new_font)
112        self.lblUnitMy.setStyleSheet(new_font)
113        self.lblUnitMz.setStyleSheet(new_font)
114        self.lblUnitNucl.setStyleSheet(new_font)
115        self.lblUnitx.setStyleSheet(new_font)
116        self.lblUnity.setStyleSheet(new_font)
117        self.lblUnitz.setStyleSheet(new_font)
118
119    def selectedshapechange(self):
120        """
121        TODO Temporary solution to display information about option 'Ellipsoid'
122        """
123        print("The option Ellipsoid has not been implemented yet.")
124        self.communicator.statusBarUpdateSignal.emit(
125            "The option Ellipsoid has not been implemented yet.")
126
127    def loadFile(self):
128        """
129        Open menu to choose the datafile to load
130        Only extensions .SLD, .PDB, .OMF, .sld, .pdb, .omf
131        """
132        try:
133            self.datafile = QtWidgets.QFileDialog.getOpenFileName(
134                self, "Choose a file", "", "All Gen files (*.OMF *.omf) ;;"
135                                          "SLD files (*.SLD *.sld);;PDB files (*.pdb *.PDB);; "
136                                          "OMF files (*.OMF *.omf);; "
137                                          "All files (*.*)")[0]
138            if self.datafile:
139                self.default_shape = str(self.cbShape.currentText())
140                self.file_name = os.path.basename(str(self.datafile))
141                self.ext = os.path.splitext(str(self.datafile))[1]
142                if self.ext in self.omf_reader.ext:
143                    loader = self.omf_reader
144                elif self.ext in self.sld_reader.ext:
145                    loader = self.sld_reader
146                elif self.ext in self.pdb_reader.ext:
147                    loader = self.pdb_reader
148                else:
149                    loader = None
150                # disable some entries depending on type of loaded file
151                # (according to documentation)
152                if self.ext.lower() in ['.sld', '.omf', '.pdb']:
153                    self.txtUpFracIn.setEnabled(False)
154                    self.txtUpFracOut.setEnabled(False)
155                    self.txtUpTheta.setEnabled(False)
156
157                if self.reader is not None and self.reader.isrunning():
158                    self.reader.stop()
159                self.cmdLoad.setEnabled(False)
160                self.cmdLoad.setText('Loading...')
161                self.communicator.statusBarUpdateSignal.emit(
162                    "Loading File {}".format(os.path.basename(
163                        str(self.datafile))))
164                self.reader = GenReader(path=str(self.datafile), loader=loader,
165                                        completefn=self.complete_loading,
166                                        updatefn=self.load_update)
167                self.reader.queue()
168        except (RuntimeError, IOError):
169            log_msg = "Generic SAS Calculator: %s" % sys.exc_info()[1]
170            logging.info(log_msg)
171            raise
172        return
173
174    def load_update(self):
175        """ Legacy function used in GenRead """
176        if self.reader.isrunning():
177            status_type = "progress"
178        else:
179            status_type = "stop"
180        logging.info(status_type)
181
182    def complete_loading(self, data=None):
183        """ Function used in GenRead"""
184        self.cbShape.setEnabled(False)
185        try:
186            is_pdbdata = False
187            self.txtData.setText(os.path.basename(str(self.datafile)))
188            self.is_avg = False
189            if self.ext in self.omf_reader.ext:
190                gen = sas_gen.OMF2SLD()
191                gen.set_data(data)
192                self.sld_data = gen.get_magsld()
193                self.check_units()
194            elif self.ext in self.sld_reader.ext:
195                self.sld_data = data
196            elif self.ext in self.pdb_reader.ext:
197                self.sld_data = data
198                is_pdbdata = True
199            # Display combobox of orientation only for pdb data
200            self.cbOptionsCalc.setVisible(is_pdbdata)
201            self.update_gui()
202        except IOError:
203            log_msg = "Loading Error: " \
204                      "This file format is not supported for GenSAS."
205            logging.info(log_msg)
206            raise
207        except ValueError:
208            log_msg = "Could not find any data"
209            logging.info(log_msg)
210            raise
211        logging.info("Load Complete")
212        self.cmdLoad.setEnabled(True)
213        self.cmdLoad.setText('Load')
214        self.trigger_plot_3d.emit()
215
216    def check_units(self):
217        """
218        Check if the units from the OMF file correspond to the default ones
219        displayed on the interface.
220        If not, modify the GUI with the correct unit
221        """
222        #  TODO: adopt the convention of font and symbol for the updated values
223        if sas_gen.OMFData().valueunit != 'A^(-2)':
224            value_unit = sas_gen.OMFData().valueunit
225            self.lbl_unitMx.setText(value_unit)
226            self.lbl_unitMy.setText(value_unit)
227            self.lbl_unitMz.setText(value_unit)
228            self.lbl_unitNucl.setText(value_unit)
229        if sas_gen.OMFData().meshunit != 'A':
230            mesh_unit = sas_gen.OMFData().meshunit
231            self.lbl_unitx.setText(mesh_unit)
232            self.lbl_unity.setText(mesh_unit)
233            self.lbl_unitz.setText(mesh_unit)
234            self.lbl_unitVolume.setText(mesh_unit+"^3")
235
236    def check_value(self):
237        """Check range of text edits for QMax and Number of Qbins """
238        text_edit = self.sender()
239        text_edit.setStyleSheet('background-color: rgb(255, 255, 255);')
240        if text_edit.text():
241            value = float(str(text_edit.text()))
242            if text_edit == self.txtQxMax:
243                if value <= 0 or value > 1000:
244                    text_edit.setStyleSheet('background-color: rgb(255, 182, 193);')
245                else:
246                    text_edit.setStyleSheet('background-color: rgb(255, 255, 255);')
247            elif text_edit == self.txtNoQBins:
248                if value < 2 or value > 1000:
249                    self.txtNoQBins.setStyleSheet('background-color: rgb(255, 182, 193);')
250                else:
251                    self.txtNoQBins.setStyleSheet('background-color: rgb(255, 255, 255);')
252
253    def update_gui(self):
254        """ Update the interface with values from loaded data """
255        self.model.set_is_avg(self.is_avg)
256        self.model.set_sld_data(self.sld_data)
257        self.model.params['total_volume'] = len(self.sld_data.sld_n)*self.sld_data.vol_pix[0]
258
259        # add condition for activation of save button
260        self.cmdSave.setEnabled(True)
261
262        # activation of 3D plots' buttons (with and without arrows)
263        self.cmdDraw.setEnabled(self.sld_data is not None)
264        self.cmdDrawpoints.setEnabled(self.sld_data is not None)
265
266        self.txtScale.setText(str(self.model.params['scale']))
267        self.txtBackground.setText(str(self.model.params['background']))
268        self.txtSolventSLD.setText(str(self.model.params['solvent_SLD']))
269
270        # Volume to write to interface: npts x volume of first pixel
271        self.txtTotalVolume.setText(str(len(self.sld_data.sld_n)*self.sld_data.vol_pix[0]))
272
273        self.txtUpFracIn.setText(str(self.model.params['Up_frac_in']))
274        self.txtUpFracOut.setText(str(self.model.params['Up_frac_out']))
275        self.txtUpTheta.setText(str(self.model.params['Up_theta']))
276
277        self.txtNoPixels.setText(str(len(self.sld_data.sld_n)))
278        self.txtNoPixels.setEnabled(False)
279
280        list_parameters = ['sld_mx', 'sld_my', 'sld_mz', 'sld_n', 'xnodes',
281                           'ynodes', 'znodes', 'xstepsize', 'ystepsize',
282                           'zstepsize']
283        list_gui_button = [self.txtMx, self.txtMy, self.txtMz, self.txtNucl,
284                           self.txtXnodes, self.txtYnodes, self.txtZnodes,
285                           self.txtXstepsize, self.txtYstepsize,
286                           self.txtZstepsize]
287
288        # Fill right hand side of GUI
289        for indx, item in enumerate(list_parameters):
290            if getattr(self.sld_data, item) is None:
291                list_gui_button[indx].setText('NaN')
292            else:
293                value = getattr(self.sld_data, item)
294                if isinstance(value, numpy.ndarray):
295                    item_for_gui = str(GuiUtils.formatNumber(numpy.average(value), True))
296                else:
297                    item_for_gui = str(GuiUtils.formatNumber(value, True))
298                list_gui_button[indx].setText(item_for_gui)
299
300        # Enable / disable editing of right hand side of GUI
301        for indx, item in enumerate(list_parameters):
302            if indx < 4:
303                # this condition only applies to Mx,y,z and Nucl
304                value = getattr(self.sld_data, item)
305                enable = self.sld_data.pix_type == 'pixel' \
306                         and numpy.min(value) == numpy.max(value)
307            else:
308                enable = not self.sld_data.is_data
309            list_gui_button[indx].setEnabled(enable)
310
311    def write_new_values_from_gui(self):
312        """
313        update parameters using modified inputs from GUI
314        used before computing
315        """
316        if self.txtScale.isModified():
317            self.model.params['scale'] = float(self.txtScale.text())
318
319        if self.txtBackground.isModified():
320            self.model.params['background'] = float(self.txtBackground.text())
321
322        if self.txtSolventSLD.isModified():
323            self.model.params['solvent_SLD'] = float(self.txtSolventSLD.text())
324
325        # Different condition for total volume to get correct volume after
326        # applying set_sld_data in compute
327        if self.txtTotalVolume.isModified() \
328                or self.model.params['total_volume'] != float(self.txtTotalVolume.text()):
329            self.model.params['total_volume'] = float(self.txtTotalVolume.text())
330
331        if self.txtUpFracIn.isModified():
332            self.model.params['Up_frac_in'] = float(self.txtUpFracIn.text())
333
334        if self.txtUpFracOut.isModified():
335            self.model.params['Up_frac_out'] = float(self.txtUpFracOut.text())
336
337        if self.txtUpTheta.isModified():
338            self.model.params['Up_theta'] = float(self.txtUpTheta.text())
339
340        if self.txtMx.isModified():
341            self.sld_data.sld_mx = float(self.txtMx.text())*\
342                                   numpy.ones(len(self.sld_data.sld_mx))
343
344        if self.txtMy.isModified():
345            self.sld_data.sld_my = float(self.txtMy.text())*\
346                                   numpy.ones(len(self.sld_data.sld_my))
347
348        if self.txtMz.isModified():
349            self.sld_data.sld_mz = float(self.txtMz.text())*\
350                                   numpy.ones(len(self.sld_data.sld_mz))
351
352        if self.txtNucl.isModified():
353            self.sld_data.sld_n = float(self.txtNucl.text())*\
354                                  numpy.ones(len(self.sld_data.sld_n))
355
356        if self.txtXnodes.isModified():
357            self.sld_data.xnodes = int(self.txtXnodes.text())
358
359        if self.txtYnodes.isModified():
360            self.sld_data.ynodes = int(self.txtYnodes.text())
361
362        if self.txtZnodes.isModified():
363            self.sld_data.znodes = int(self.txtZnodes.text())
364
365        if self.txtXstepsize.isModified():
366            self.sld_data.xstepsize = float(self.txtXstepsize.text())
367
368        if self.txtYstepsize.isModified():
369            self.sld_data.ystepsize = float(self.txtYstepsize.text())
370
371        if self.txtZstepsize.isModified():
372            self.sld_data.zstepsize = float(self.txtZstepsize.text())
373
374        if self.cbOptionsCalc.isVisible():
375            self.is_avg = (self.cbOptionsCalc.currentIndex() == 1)
376
377    def onHelp(self):
378        """
379        Bring up the Generic Scattering calculator Documentation whenever
380        the HELP button is clicked.
381        Calls Documentation Window with the path of the location within the
382        documentation tree (after /doc/ ....".
383        """
384        try:
385            location = GuiUtils.HELP_DIRECTORY_LOCATION + \
386                       "/user/sasgui/perspectives/calculator/sas_calculator_help.html"
387            self.manager._helpView.load(QtCore.QUrl(location))
388            self.manager._helpView.show()
389        except AttributeError:
390            # No manager defined - testing and standalone runs
391            pass
392
393    def onReset(self):
394        """ Reset the inputs of textEdit to default values """
395        try:
396            # reset values in textedits
397            self.txtUpFracIn.setText("1.0")
398            self.txtUpFracOut.setText("1.0")
399            self.txtUpTheta.setText("0.0")
400            self.txtBackground.setText("0.0")
401            self.txtScale.setText("1.0")
402            self.txtSolventSLD.setText("0.0")
403            self.txtTotalVolume.setText("216000.0")
404            self.txtNoQBins.setText("50")
405            self.txtQxMax.setText("0.3")
406            self.txtNoPixels.setText("1000")
407            self.txtMx.setText("0")
408            self.txtMy.setText("0")
409            self.txtMz.setText("0")
410            self.txtNucl.setText("6.97e-06")
411            self.txtXnodes.setText("10")
412            self.txtYnodes.setText("10")
413            self.txtZnodes.setText("10")
414            self.txtXstepsize.setText("6")
415            self.txtYstepsize.setText("6")
416            self.txtZstepsize.setText("6")
417            # reset Load button and textedit
418            self.txtData.setText('Default SLD Profile')
419            self.cmdLoad.setEnabled(True)
420            self.cmdLoad.setText('Load')
421            # reset option for calculation
422            self.cbOptionsCalc.setCurrentIndex(0)
423            self.cbOptionsCalc.setVisible(False)
424            # reset shape button
425            self.cbShape.setCurrentIndex(0)
426            self.cbShape.setEnabled(True)
427            # reset compute button
428            self.cmdCompute.setText('Compute')
429            self.cmdCompute.setEnabled(True)
430            # TODO reload default data set
431            self._create_default_sld_data()
432
433        finally:
434            pass
435
436    def _create_default_2d_data(self):
437        """
438        Copied from previous version
439        Create 2D data by default
440        :warning: This data is never plotted.
441        """
442        self.qmax_x = float(self.txtQxMax.text())
443        self.npts_x = int(self.txtNoQBins.text())
444        self.data = Data2D()
445        self.data.is_data = False
446        # # Default values
447        self.data.detector.append(Detector())
448        index = len(self.data.detector) - 1
449        self.data.detector[index].distance = 8000  # mm
450        self.data.source.wavelength = 6  # A
451        self.data.detector[index].pixel_size.x = 5  # mm
452        self.data.detector[index].pixel_size.y = 5  # mm
453        self.data.detector[index].beam_center.x = self.qmax_x
454        self.data.detector[index].beam_center.y = self.qmax_x
455        xmax = self.qmax_x
456        xmin = -xmax
457        ymax = self.qmax_x
458        ymin = -ymax
459        qstep = self.npts_x
460
461        x = numpy.linspace(start=xmin, stop=xmax, num=qstep, endpoint=True)
462        y = numpy.linspace(start=ymin, stop=ymax, num=qstep, endpoint=True)
463        # use data info instead
464        new_x = numpy.tile(x, (len(y), 1))
465        new_y = numpy.tile(y, (len(x), 1))
466        new_y = new_y.swapaxes(0, 1)
467        # all data require now in 1d array
468        qx_data = new_x.flatten()
469        qy_data = new_y.flatten()
470        q_data = numpy.sqrt(qx_data * qx_data + qy_data * qy_data)
471        # set all True (standing for unmasked) as default
472        mask = numpy.ones(len(qx_data), dtype=bool)
473        self.data.source = Source()
474        self.data.data = numpy.ones(len(mask))
475        self.data.err_data = numpy.ones(len(mask))
476        self.data.qx_data = qx_data
477        self.data.qy_data = qy_data
478        self.data.q_data = q_data
479        self.data.mask = mask
480        # store x and y bin centers in q space
481        self.data.x_bins = x
482        self.data.y_bins = y
483        # max and min taking account of the bin sizes
484        self.data.xmin = xmin
485        self.data.xmax = xmax
486        self.data.ymin = ymin
487        self.data.ymax = ymax
488
489    def _create_default_sld_data(self):
490        """
491        Copied from previous version
492        Making default sld-data
493        """
494        sld_n_default = 6.97e-06  # what is this number??
495        omfdata = sas_gen.OMFData()
496        omf2sld = sas_gen.OMF2SLD()
497        omf2sld.set_data(omfdata, self.default_shape)
498        self.sld_data = omf2sld.output
499        self.sld_data.is_data = False
500        self.sld_data.filename = "Default SLD Profile"
501        self.sld_data.set_sldn(sld_n_default)
502
503    def _create_default_1d_data(self):
504        """
505        Copied from previous version
506        Create 1D data by default
507        :warning: This data is never plotted.
508                    residuals.x = data_copy.x[index]
509            residuals.dy = numpy.ones(len(residuals.y))
510            residuals.dx = None
511            residuals.dxl = None
512            residuals.dxw = None
513        """
514        self.qmax_x = float(self.txtQxMax.text())
515        self.npts_x = int(self.txtNoQBins.text())
516        #  Default values
517        xmax = self.qmax_x
518        xmin = self.qmax_x * _Q1D_MIN
519        qstep = self.npts_x
520        x = numpy.linspace(start=xmin, stop=xmax, num=qstep, endpoint=True)
521        # store x and y bin centers in q space
522        y = numpy.ones(len(x))
523        dy = numpy.zeros(len(x))
524        dx = numpy.zeros(len(x))
525        self.data = Data1D(x=x, y=y)
526        self.data.dx = dx
527        self.data.dy = dy
528
529    def onCompute(self):
530        """
531        Copied from previous version
532        Execute the computation of I(qx, qy)
533        """
534        # Set default data when nothing loaded yet
535        if self.sld_data is None:
536            self._create_default_sld_data()
537        try:
538            self.model.set_sld_data(self.sld_data)
539            self.write_new_values_from_gui()
540            if self.is_avg or self.is_avg is None:
541                self._create_default_1d_data()
542                i_out = numpy.zeros(len(self.data.y))
543                inputs = [self.data.x, [], i_out]
544            else:
545                self._create_default_2d_data()
546                i_out = numpy.zeros(len(self.data.data))
547                inputs = [self.data.qx_data, self.data.qy_data, i_out]
548            logging.info("Computation is in progress...")
549            self.cmdCompute.setText('Wait...')
550            self.cmdCompute.setEnabled(False)
551            d = threads.deferToThread(self.complete, inputs, self._update)
552            # Add deferred callback for call return
553            d.addCallback(self.plot_1_2d)
554            d.addErrback(self.calculateFailed)
555        except:
556            log_msg = "{}. stop".format(sys.exc_info()[1])
557            logging.info(log_msg)
558        return
559
560    def _update(self, value):
561        """
562        Copied from previous version
563        """
564        pass
565
566    def calculateFailed(self, reason):
567        """
568        """
569        print("Calculate Failed with:\n", reason)
570        pass
571
572    def complete(self, input, update=None):
573        """
574        Gen compute complete function
575        :Param input: input list [qx_data, qy_data, i_out]
576        """
577        out = numpy.empty(0)
578        for ind in range(len(input[0])):
579            if self.is_avg:
580                if ind % 1 == 0 and update is not None:
581                    # update()
582                    percentage = int(100.0 * float(ind) / len(input[0]))
583                    update(percentage)
584                    time.sleep(0.001)  # 0.1
585                inputi = [input[0][ind:ind + 1], [], input[2][ind:ind + 1]]
586                outi = self.model.run(inputi)
587                out = numpy.append(out, outi)
588            else:
589                if ind % 50 == 0 and update is not None:
590                    percentage = int(100.0 * float(ind) / len(input[0]))
591                    update(percentage)
592                    time.sleep(0.001)
593                inputi = [input[0][ind:ind + 1], input[1][ind:ind + 1],
594                          input[2][ind:ind + 1]]
595                outi = self.model.runXY(inputi)
596                out = numpy.append(out, outi)
597        self.data_to_plot = out
598        logging.info('Gen computation completed.')
599        self.cmdCompute.setText('Compute')
600        self.cmdCompute.setEnabled(True)
601        return
602
603    def onSaveFile(self):
604        """Save data as .sld file"""
605        path = os.path.dirname(str(self.datafile))
606        default_name = os.path.join(path, 'sld_file')
607        kwargs = {
608            'parent': self,
609            'directory': default_name,
610            'filter': 'SLD file (*.sld)',
611            'options': QtWidgets.QFileDialog.DontUseNativeDialog}
612        # Query user for filename.
613        filename_tuple = QtWidgets.QFileDialog.getSaveFileName(**kwargs)
614        filename = filename_tuple[0]
615        if filename:
616            try:
617                _, extension = os.path.splitext(filename)
618                if not extension:
619                    filename = '.'.join((filename, 'sld'))
620                sas_gen.SLDReader().write(filename, self.sld_data)
621            except:
622                raise
623
624    def plot3d(self, has_arrow=False):
625        """ Generate 3D plot in real space with or without arrows """
626        self.write_new_values_from_gui()
627        graph_title = " Graph {}: {} 3D SLD Profile".format(self.graph_num,
628                                                            self.file_name)
629        if has_arrow:
630            graph_title += ' - Magnetic Vector as Arrow'
631
632        plot3D = Plotter3D(self, graph_title)
633        plot3D.plot(self.sld_data, has_arrow=has_arrow)
634        plot3D.show()
635        self.graph_num += 1
636
637    def plot_1_2d(self, d):
638        """ Generate 1D or 2D plot, called in Compute"""
639        if self.is_avg or self.is_avg is None:
640            data = Data1D(x=self.data.x, y=self.data_to_plot)
641            data.title = "GenSAS {}  #{} 1D".format(self.file_name,
642                                                    int(self.graph_num))
643            data.xaxis('\\rm{Q_{x}}', '\AA^{-1}')
644            data.yaxis('\\rm{Intensity}', 'cm^{-1}')
645            plot1D = Plotter(self)
646            plot1D.plot(data)
647            plot1D.show()
648            self.graph_num += 1
649            # TODO
650            print('TRANSFER OF DATA TO MAIN PANEL TO BE IMPLEMENTED')
651            return plot1D
652        else:
653            numpy.nan_to_num(self.data_to_plot)
654            data = Data2D(image=self.data_to_plot,
655                          qx_data=self.data.qx_data,
656                          qy_data=self.data.qy_data,
657                          q_data=self.data.q_data,
658                          xmin=self.data.xmin, xmax=self.data.ymax,
659                          ymin=self.data.ymin, ymax=self.data.ymax,
660                          err_image=self.data.err_data)
661            data.title = "GenSAS {}  #{} 2D".format(self.file_name,
662                                                    int(self.graph_num))
663            plot2D = Plotter2D(self)
664            plot2D.plot(data)
665            plot2D.show()
666            self.graph_num += 1
667            # TODO
668            print('TRANSFER OF DATA TO MAIN PANEL TO BE IMPLEMENTED')
669            return plot2D
670
671
672class Plotter3DWidget(PlotterBase):
673    """
674    3D Plot widget for use with a QDialog
675    """
676    def __init__(self, parent=None, manager=None):
677        super(Plotter3DWidget, self).__init__(parent,  manager=manager)
678
679    @property
680    def data(self):
681        return self._data
682
683    @data.setter
684    def data(self, data=None):
685        """ data setter """
686        self._data = data
687
688    def plot(self, data=None, has_arrow=False):
689        """
690        Plot 3D self._data
691        """
692        if not data:
693            return
694        self.data = data
695        #assert(self._data)
696        # Prepare and show the plot
697        self.showPlot(data=self.data, has_arrow=has_arrow)
698
699    def showPlot(self, data, has_arrow=False):
700        """
701        Render and show the current data
702        """
703        # If we don't have any data, skip.
704        if data is None:
705            return
706        # This import takes forever - place it here so the main UI starts faster
707        from mpl_toolkits.mplot3d import Axes3D
708        color_dic = {'H': 'blue', 'D': 'purple', 'N': 'orange',
709                     'O': 'red', 'C': 'green', 'P': 'cyan', 'Other': 'k'}
710        marker = ','
711        m_size = 2
712
713        pos_x = data.pos_x
714        pos_y = data.pos_y
715        pos_z = data.pos_z
716        sld_mx = data.sld_mx
717        sld_my = data.sld_my
718        sld_mz = data.sld_mz
719        pix_symbol = data.pix_symbol
720        sld_tot = numpy.fabs(sld_mx) + numpy.fabs(sld_my) + \
721                  numpy.fabs(sld_mz) + numpy.fabs(data.sld_n)
722        is_nonzero = sld_tot > 0.0
723        is_zero = sld_tot == 0.0
724
725        if data.pix_type == 'atom':
726            marker = 'o'
727            m_size = 3.5
728
729        self.figure.clear()
730        self.figure.subplots_adjust(left=0.1, right=.8, bottom=.1)
731        ax = Axes3D(self.figure)
732        ax.set_xlabel('x ($\A{}$)'.format(data.pos_unit))
733        ax.set_ylabel('z ($\A{}$)'.format(data.pos_unit))
734        ax.set_zlabel('y ($\A{}$)'.format(data.pos_unit))
735
736        # I. Plot null points
737        if is_zero.any():
738            im = ax.plot(pos_x[is_zero], pos_z[is_zero], pos_y[is_zero],
739                           marker, c="y", alpha=0.5, markeredgecolor='y',
740                           markersize=m_size)
741            pos_x = pos_x[is_nonzero]
742            pos_y = pos_y[is_nonzero]
743            pos_z = pos_z[is_nonzero]
744            sld_mx = sld_mx[is_nonzero]
745            sld_my = sld_my[is_nonzero]
746            sld_mz = sld_mz[is_nonzero]
747            pix_symbol = data.pix_symbol[is_nonzero]
748        # II. Plot selective points in color
749        other_color = numpy.ones(len(pix_symbol), dtype='bool')
750        for key in list(color_dic.keys()):
751            chosen_color = pix_symbol == key
752            if numpy.any(chosen_color):
753                other_color = other_color & (chosen_color!=True)
754                color = color_dic[key]
755                im = ax.plot(pos_x[chosen_color], pos_z[chosen_color],
756                         pos_y[chosen_color], marker, c=color, alpha=0.5,
757                         markeredgecolor=color, markersize=m_size, label=key)
758        # III. Plot All others
759        if numpy.any(other_color):
760            a_name = ''
761            if data.pix_type == 'atom':
762                # Get atom names not in the list
763                a_names = [symb for symb in pix_symbol \
764                           if symb not in list(color_dic.keys())]
765                a_name = a_names[0]
766                for name in a_names:
767                    new_name = ", " + name
768                    if a_name.count(name) == 0:
769                        a_name += new_name
770            # plot in black
771            im = ax.plot(pos_x[other_color], pos_z[other_color],
772                         pos_y[other_color], marker, c="k", alpha=0.5,
773                         markeredgecolor="k", markersize=m_size, label=a_name)
774        if data.pix_type == 'atom':
775            ax.legend(loc='upper left', prop={'size': 10})
776        # IV. Draws atomic bond with grey lines if any
777        if data.has_conect:
778            for ind in range(len(data.line_x)):
779                im = ax.plot(data.line_x[ind], data.line_z[ind],
780                             data.line_y[ind], '-', lw=0.6, c="grey",
781                             alpha=0.3)
782        # V. Draws magnetic vectors
783        if has_arrow and len(pos_x) > 0:
784            def _draw_arrow(input=None, update=None):
785                # import moved here for performance reasons
786                from sas.qtgui.Plotting.Arrow3D import Arrow3D
787                """
788                draw magnetic vectors w/arrow
789                """
790                max_mx = max(numpy.fabs(sld_mx))
791                max_my = max(numpy.fabs(sld_my))
792                max_mz = max(numpy.fabs(sld_mz))
793                max_m = max(max_mx, max_my, max_mz)
794                try:
795                    max_step = max(data.xstepsize, data.ystepsize, data.zstepsize)
796                except:
797                    max_step = 0
798                if max_step <= 0:
799                    max_step = 5
800                try:
801                    if max_m != 0:
802                        unit_x2 = sld_mx / max_m
803                        unit_y2 = sld_my / max_m
804                        unit_z2 = sld_mz / max_m
805                        # 0.8 is for avoiding the color becomes white=(1,1,1))
806                        color_x = numpy.fabs(unit_x2 * 0.8)
807                        color_y = numpy.fabs(unit_y2 * 0.8)
808                        color_z = numpy.fabs(unit_z2 * 0.8)
809                        x2 = pos_x + unit_x2 * max_step
810                        y2 = pos_y + unit_y2 * max_step
811                        z2 = pos_z + unit_z2 * max_step
812                        x_arrow = numpy.column_stack((pos_x, x2))
813                        y_arrow = numpy.column_stack((pos_y, y2))
814                        z_arrow = numpy.column_stack((pos_z, z2))
815                        colors = numpy.column_stack((color_x, color_y, color_z))
816                        arrows = Arrow3D(self.figure, x_arrow, z_arrow, y_arrow,
817                                        colors, mutation_scale=10, lw=1,
818                                        arrowstyle="->", alpha=0.5)
819                        ax.add_artist(arrows)
820                except:
821                    pass
822                log_msg = "Arrow Drawing completed.\n"
823                logging.info(log_msg)
824            log_msg = "Arrow Drawing is in progress..."
825            logging.info(log_msg)
826
827            # Defer the drawing of arrows to another thread
828            d = threads.deferToThread(_draw_arrow, ax)
829
830        self.figure.canvas.resizing = False
831        self.figure.canvas.draw()
832
833
834class Plotter3D(QtWidgets.QDialog, Plotter3DWidget):
835    def __init__(self, parent=None, graph_title=''):
836        self.graph_title = graph_title
837        QtWidgets.QDialog.__init__(self)
838        Plotter3DWidget.__init__(self, manager=parent)
839        self.setWindowTitle(self.graph_title)
840
Note: See TracBrowser for help on using the repository browser.