source: sasview/src/sas/qtgui/Calculators/GenericScatteringCalculator.py @ ebf86f1

ESS_GUIESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalc
Last change on this file since ebf86f1 was ebf86f1, checked in by wojciech, 8 months ago

Fixing failing Generic Scattering Calculator on OMF file on Linux

  • Property mode set to 100644
File size: 34.3 KB
Line 
1import sys
2import os
3import numpy
4import logging
5import time
6
7from PyQt5 import QtCore
8from PyQt5 import QtGui
9from PyQt5 import QtWidgets
10
11from twisted.internet import threads
12
13import sas.qtgui.Utilities.GuiUtils as GuiUtils
14from sas.qtgui.Utilities.GenericReader import GenReader
15from sas.sascalc.dataloader.data_info import Detector
16from sas.sascalc.dataloader.data_info import Source
17from sas.sascalc.calculator import sas_gen
18from sas.qtgui.Plotting.PlotterBase import PlotterBase
19from sas.qtgui.Plotting.Plotter2D import Plotter2D
20from sas.qtgui.Plotting.Plotter import Plotter
21
22from sas.qtgui.Plotting.PlotterData import Data1D
23from sas.qtgui.Plotting.PlotterData import Data2D
24
25# Local UI
26from .UI.GenericScatteringCalculator import Ui_GenericScatteringCalculator
27
28_Q1D_MIN = 0.001
29
30
31class GenericScatteringCalculator(QtWidgets.QDialog, Ui_GenericScatteringCalculator):
32
33    trigger_plot_3d = QtCore.pyqtSignal()
34    calculationFinishedSignal = QtCore.pyqtSignal()
35    loadingFinishedSignal = QtCore.pyqtSignal(list)
36
37    def __init__(self, parent=None):
38        super(GenericScatteringCalculator, self).__init__()
39        self.setupUi(self)
40        self.manager = parent
41        self.communicator = self.manager.communicator()
42        self.model = sas_gen.GenSAS()
43        self.omf_reader = sas_gen.OMFReader()
44        self.sld_reader = sas_gen.SLDReader()
45        self.pdb_reader = sas_gen.PDBReader()
46        self.reader = None
47        self.sld_data = None
48
49        self.parameters = []
50        self.data = None
51        self.datafile = None
52        self.file_name = ''
53        self.ext = None
54        self.default_shape = str(self.cbShape.currentText())
55        self.is_avg = False
56        self.data_to_plot = None
57        self.graph_num = 1  # index for name of graph
58
59        # combox box
60        self.cbOptionsCalc.setVisible(False)
61
62        # push buttons
63        self.cmdClose.clicked.connect(self.accept)
64        self.cmdHelp.clicked.connect(self.onHelp)
65
66        self.cmdLoad.clicked.connect(self.loadFile)
67        self.cmdCompute.clicked.connect(self.onCompute)
68        self.cmdReset.clicked.connect(self.onReset)
69        self.cmdSave.clicked.connect(self.onSaveFile)
70
71        self.cmdDraw.clicked.connect(lambda: self.plot3d(has_arrow=True))
72        self.cmdDrawpoints.clicked.connect(lambda: self.plot3d(has_arrow=False))
73
74        # validators
75        # scale, volume and background must be positive
76        validat_regex_pos = QtCore.QRegExp('^[+]?([.]\d+|\d+([.]\d+)?)$')
77        self.txtScale.setValidator(QtGui.QRegExpValidator(validat_regex_pos,
78                                                          self.txtScale))
79        self.txtBackground.setValidator(QtGui.QRegExpValidator(
80            validat_regex_pos, self.txtBackground))
81        self.txtTotalVolume.setValidator(QtGui.QRegExpValidator(
82            validat_regex_pos, self.txtTotalVolume))
83
84        # fraction of spin up between 0 and 1
85        validat_regexbetween0_1 = QtCore.QRegExp('^(0(\.\d*)*|1(\.0+)?)$')
86        self.txtUpFracIn.setValidator(
87            QtGui.QRegExpValidator(validat_regexbetween0_1, self.txtUpFracIn))
88        self.txtUpFracOut.setValidator(
89            QtGui.QRegExpValidator(validat_regexbetween0_1, self.txtUpFracOut))
90
91        # 0 < Qmax <= 1000
92        validat_regex_q = QtCore.QRegExp('^1000$|^[+]?(\d{1,3}([.]\d+)?)$')
93        self.txtQxMax.setValidator(QtGui.QRegExpValidator(validat_regex_q,
94                                                          self.txtQxMax))
95        self.txtQxMax.textChanged.connect(self.check_value)
96
97        # 2 <= Qbin <= 1000
98        self.txtNoQBins.setValidator(QtGui.QRegExpValidator(validat_regex_q,
99                                                            self.txtNoQBins))
100        self.txtNoQBins.textChanged.connect(self.check_value)
101
102        # plots - 3D in real space
103        self.trigger_plot_3d.connect(lambda: self.plot3d(has_arrow=False))
104
105        # plots - 3D in real space
106        self.calculationFinishedSignal.connect(self.plot_1_2d)
107
108        # notify main thread about file load complete
109        self.loadingFinishedSignal.connect(self.complete_loading)
110
111        # TODO the option Ellipsoid has not been implemented
112        self.cbShape.currentIndexChanged.connect(self.selectedshapechange)
113
114        # New font to display angstrom symbol
115        new_font = 'font-family: -apple-system, "Helvetica Neue", "Ubuntu";'
116        self.lblUnitSolventSLD.setStyleSheet(new_font)
117        self.lblUnitVolume.setStyleSheet(new_font)
118        self.lbl5.setStyleSheet(new_font)
119        self.lblUnitMx.setStyleSheet(new_font)
120        self.lblUnitMy.setStyleSheet(new_font)
121        self.lblUnitMz.setStyleSheet(new_font)
122        self.lblUnitNucl.setStyleSheet(new_font)
123        self.lblUnitx.setStyleSheet(new_font)
124        self.lblUnity.setStyleSheet(new_font)
125        self.lblUnitz.setStyleSheet(new_font)
126
127    def selectedshapechange(self):
128        """
129        TODO Temporary solution to display information about option 'Ellipsoid'
130        """
131        print("The option Ellipsoid has not been implemented yet.")
132        self.communicator.statusBarUpdateSignal.emit(
133            "The option Ellipsoid has not been implemented yet.")
134
135    def loadFile(self):
136        """
137        Open menu to choose the datafile to load
138        Only extensions .SLD, .PDB, .OMF, .sld, .pdb, .omf
139        """
140        try:
141            self.datafile = QtWidgets.QFileDialog.getOpenFileName(
142                self, "Choose a file", "", "All Gen files (*.OMF *.omf) ;;"
143                                          "SLD files (*.SLD *.sld);;PDB files (*.pdb *.PDB);; "
144                                          "OMF files (*.OMF *.omf);; "
145                                          "All files (*.*)")[0]
146            if self.datafile:
147                self.default_shape = str(self.cbShape.currentText())
148                self.file_name = os.path.basename(str(self.datafile))
149                self.ext = os.path.splitext(str(self.datafile))[1]
150                if self.ext in self.omf_reader.ext:
151                    loader = self.omf_reader
152                elif self.ext in self.sld_reader.ext:
153                    loader = self.sld_reader
154                elif self.ext in self.pdb_reader.ext:
155                    loader = self.pdb_reader
156                else:
157                    loader = None
158                # disable some entries depending on type of loaded file
159                # (according to documentation)
160                if self.ext.lower() in ['.sld', '.omf', '.pdb']:
161                    self.txtUpFracIn.setEnabled(False)
162                    self.txtUpFracOut.setEnabled(False)
163                    self.txtUpTheta.setEnabled(False)
164
165                if self.reader is not None and self.reader.isrunning():
166                    self.reader.stop()
167                self.cmdLoad.setEnabled(False)
168                self.cmdLoad.setText('Loading...')
169                self.communicator.statusBarUpdateSignal.emit(
170                    "Loading File {}".format(os.path.basename(
171                        str(self.datafile))))
172                self.reader = GenReader(path=str(self.datafile), loader=loader,
173                                        completefn=self.complete_loading_ex,
174                                        updatefn=self.load_update)
175                self.reader.queue()
176        except (RuntimeError, IOError):
177            log_msg = "Generic SAS Calculator: %s" % sys.exc_info()[1]
178            logging.info(log_msg)
179            raise
180        return
181
182    def load_update(self):
183        """ Legacy function used in GenRead """
184        if self.reader.isrunning():
185            status_type = "progress"
186        else:
187            status_type = "stop"
188        logging.info(status_type)
189
190    def complete_loading_ex(self, data=None):
191        """
192        Send the finish message from calculate threads to main thread
193        """
194        self.loadingFinishedSignal.emit(data)
195
196    def complete_loading(self, data=None):
197        """ Function used in GenRead"""
198        assert isinstance(data, list)
199        assert len(data)==1
200        data = data[0]
201        self.cbShape.setEnabled(False)
202        try:
203            is_pdbdata = False
204            self.txtData.setText(os.path.basename(str(self.datafile)))
205            self.is_avg = False
206            if self.ext in self.omf_reader.ext:
207                gen = sas_gen.OMF2SLD()
208                gen.set_data(data)
209                self.sld_data = gen.get_magsld()
210                self.check_units()
211            elif self.ext in self.sld_reader.ext:
212                self.sld_data = data
213            elif self.ext in self.pdb_reader.ext:
214                self.sld_data = data
215                is_pdbdata = True
216            # Display combobox of orientation only for pdb data
217            self.cbOptionsCalc.setVisible(is_pdbdata)
218            self.update_gui()
219        except IOError:
220            log_msg = "Loading Error: " \
221                      "This file format is not supported for GenSAS."
222            logging.info(log_msg)
223            raise
224        except ValueError:
225            log_msg = "Could not find any data"
226            logging.info(log_msg)
227            raise
228        logging.info("Load Complete")
229        self.cmdLoad.setEnabled(True)
230        self.cmdLoad.setText('Load')
231        self.trigger_plot_3d.emit()
232
233    def check_units(self):
234        """
235        Check if the units from the OMF file correspond to the default ones
236        displayed on the interface.
237        If not, modify the GUI with the correct unit
238        """
239        #  TODO: adopt the convention of font and symbol for the updated values
240        if sas_gen.OMFData().valueunit != 'A^(-2)':
241            value_unit = sas_gen.OMFData().valueunit
242            self.lbl_unitMx.setText(value_unit)
243            self.lbl_unitMy.setText(value_unit)
244            self.lbl_unitMz.setText(value_unit)
245            self.lbl_unitNucl.setText(value_unit)
246        if sas_gen.OMFData().meshunit != 'A':
247            mesh_unit = sas_gen.OMFData().meshunit
248            self.lbl_unitx.setText(mesh_unit)
249            self.lbl_unity.setText(mesh_unit)
250            self.lbl_unitz.setText(mesh_unit)
251            self.lbl_unitVolume.setText(mesh_unit+"^3")
252
253    def check_value(self):
254        """Check range of text edits for QMax and Number of Qbins """
255        text_edit = self.sender()
256        text_edit.setStyleSheet('background-color: rgb(255, 255, 255);')
257        if text_edit.text():
258            value = float(str(text_edit.text()))
259            if text_edit == self.txtQxMax:
260                if value <= 0 or value > 1000:
261                    text_edit.setStyleSheet('background-color: rgb(255, 182, 193);')
262                else:
263                    text_edit.setStyleSheet('background-color: rgb(255, 255, 255);')
264            elif text_edit == self.txtNoQBins:
265                if value < 2 or value > 1000:
266                    self.txtNoQBins.setStyleSheet('background-color: rgb(255, 182, 193);')
267                else:
268                    self.txtNoQBins.setStyleSheet('background-color: rgb(255, 255, 255);')
269
270    def update_gui(self):
271        """ Update the interface with values from loaded data """
272        self.model.set_is_avg(self.is_avg)
273        self.model.set_sld_data(self.sld_data)
274        self.model.params['total_volume'] = len(self.sld_data.sld_n)*self.sld_data.vol_pix[0]
275
276        # add condition for activation of save button
277        self.cmdSave.setEnabled(True)
278
279        # activation of 3D plots' buttons (with and without arrows)
280        self.cmdDraw.setEnabled(self.sld_data is not None)
281        self.cmdDrawpoints.setEnabled(self.sld_data is not None)
282
283        self.txtScale.setText(str(self.model.params['scale']))
284        self.txtBackground.setText(str(self.model.params['background']))
285        self.txtSolventSLD.setText(str(self.model.params['solvent_SLD']))
286
287        # Volume to write to interface: npts x volume of first pixel
288        self.txtTotalVolume.setText(str(len(self.sld_data.sld_n)*self.sld_data.vol_pix[0]))
289
290        self.txtUpFracIn.setText(str(self.model.params['Up_frac_in']))
291        self.txtUpFracOut.setText(str(self.model.params['Up_frac_out']))
292        self.txtUpTheta.setText(str(self.model.params['Up_theta']))
293
294        self.txtNoPixels.setText(str(len(self.sld_data.sld_n)))
295        self.txtNoPixels.setEnabled(False)
296
297        list_parameters = ['sld_mx', 'sld_my', 'sld_mz', 'sld_n', 'xnodes',
298                           'ynodes', 'znodes', 'xstepsize', 'ystepsize',
299                           'zstepsize']
300        list_gui_button = [self.txtMx, self.txtMy, self.txtMz, self.txtNucl,
301                           self.txtXnodes, self.txtYnodes, self.txtZnodes,
302                           self.txtXstepsize, self.txtYstepsize,
303                           self.txtZstepsize]
304
305        # Fill right hand side of GUI
306        for indx, item in enumerate(list_parameters):
307            if getattr(self.sld_data, item) is None:
308                list_gui_button[indx].setText('NaN')
309            else:
310                value = getattr(self.sld_data, item)
311                if isinstance(value, numpy.ndarray):
312                    item_for_gui = str(GuiUtils.formatNumber(numpy.average(value), True))
313                else:
314                    item_for_gui = str(GuiUtils.formatNumber(value, True))
315                list_gui_button[indx].setText(item_for_gui)
316
317        # Enable / disable editing of right hand side of GUI
318        for indx, item in enumerate(list_parameters):
319            if indx < 4:
320                # this condition only applies to Mx,y,z and Nucl
321                value = getattr(self.sld_data, item)
322                enable = self.sld_data.pix_type == 'pixel' \
323                         and numpy.min(value) == numpy.max(value)
324            else:
325                enable = not self.sld_data.is_data
326            list_gui_button[indx].setEnabled(enable)
327
328    def write_new_values_from_gui(self):
329        """
330        update parameters using modified inputs from GUI
331        used before computing
332        """
333        if self.txtScale.isModified():
334            self.model.params['scale'] = float(self.txtScale.text())
335
336        if self.txtBackground.isModified():
337            self.model.params['background'] = float(self.txtBackground.text())
338
339        if self.txtSolventSLD.isModified():
340            self.model.params['solvent_SLD'] = float(self.txtSolventSLD.text())
341
342        # Different condition for total volume to get correct volume after
343        # applying set_sld_data in compute
344        if self.txtTotalVolume.isModified() \
345                or self.model.params['total_volume'] != float(self.txtTotalVolume.text()):
346            self.model.params['total_volume'] = float(self.txtTotalVolume.text())
347
348        if self.txtUpFracIn.isModified():
349            self.model.params['Up_frac_in'] = float(self.txtUpFracIn.text())
350
351        if self.txtUpFracOut.isModified():
352            self.model.params['Up_frac_out'] = float(self.txtUpFracOut.text())
353
354        if self.txtUpTheta.isModified():
355            self.model.params['Up_theta'] = float(self.txtUpTheta.text())
356
357        if self.txtMx.isModified():
358            self.sld_data.sld_mx = float(self.txtMx.text())*\
359                                   numpy.ones(len(self.sld_data.sld_mx))
360
361        if self.txtMy.isModified():
362            self.sld_data.sld_my = float(self.txtMy.text())*\
363                                   numpy.ones(len(self.sld_data.sld_my))
364
365        if self.txtMz.isModified():
366            self.sld_data.sld_mz = float(self.txtMz.text())*\
367                                   numpy.ones(len(self.sld_data.sld_mz))
368
369        if self.txtNucl.isModified():
370            self.sld_data.sld_n = float(self.txtNucl.text())*\
371                                  numpy.ones(len(self.sld_data.sld_n))
372
373        if self.txtXnodes.isModified():
374            self.sld_data.xnodes = int(self.txtXnodes.text())
375
376        if self.txtYnodes.isModified():
377            self.sld_data.ynodes = int(self.txtYnodes.text())
378
379        if self.txtZnodes.isModified():
380            self.sld_data.znodes = int(self.txtZnodes.text())
381
382        if self.txtXstepsize.isModified():
383            self.sld_data.xstepsize = float(self.txtXstepsize.text())
384
385        if self.txtYstepsize.isModified():
386            self.sld_data.ystepsize = float(self.txtYstepsize.text())
387
388        if self.txtZstepsize.isModified():
389            self.sld_data.zstepsize = float(self.txtZstepsize.text())
390
391        if self.cbOptionsCalc.isVisible():
392            self.is_avg = (self.cbOptionsCalc.currentIndex() == 1)
393
394    def onHelp(self):
395        """
396        Bring up the Generic Scattering calculator Documentation whenever
397        the HELP button is clicked.
398        Calls Documentation Window with the path of the location within the
399        documentation tree (after /doc/ ....".
400        """
401        location = "/user/qtgui/Calculators/sas_calculator_help.html"
402        self.manager.showHelp(location)
403
404    def onReset(self):
405        """ Reset the inputs of textEdit to default values """
406        try:
407            # reset values in textedits
408            self.txtUpFracIn.setText("1.0")
409            self.txtUpFracOut.setText("1.0")
410            self.txtUpTheta.setText("0.0")
411            self.txtBackground.setText("0.0")
412            self.txtScale.setText("1.0")
413            self.txtSolventSLD.setText("0.0")
414            self.txtTotalVolume.setText("216000.0")
415            self.txtNoQBins.setText("50")
416            self.txtQxMax.setText("0.3")
417            self.txtNoPixels.setText("1000")
418            self.txtMx.setText("0")
419            self.txtMy.setText("0")
420            self.txtMz.setText("0")
421            self.txtNucl.setText("6.97e-06")
422            self.txtXnodes.setText("10")
423            self.txtYnodes.setText("10")
424            self.txtZnodes.setText("10")
425            self.txtXstepsize.setText("6")
426            self.txtYstepsize.setText("6")
427            self.txtZstepsize.setText("6")
428            # reset Load button and textedit
429            self.txtData.setText('Default SLD Profile')
430            self.cmdLoad.setEnabled(True)
431            self.cmdLoad.setText('Load')
432            # reset option for calculation
433            self.cbOptionsCalc.setCurrentIndex(0)
434            self.cbOptionsCalc.setVisible(False)
435            # reset shape button
436            self.cbShape.setCurrentIndex(0)
437            self.cbShape.setEnabled(True)
438            # reset compute button
439            self.cmdCompute.setText('Compute')
440            self.cmdCompute.setEnabled(True)
441            # TODO reload default data set
442            self._create_default_sld_data()
443
444        finally:
445            pass
446
447    def _create_default_2d_data(self):
448        """
449        Copied from previous version
450        Create 2D data by default
451        :warning: This data is never plotted.
452        """
453        self.qmax_x = float(self.txtQxMax.text())
454        self.npts_x = int(self.txtNoQBins.text())
455        self.data = Data2D()
456        self.data.is_data = False
457        # # Default values
458        self.data.detector.append(Detector())
459        index = len(self.data.detector) - 1
460        self.data.detector[index].distance = 8000  # mm
461        self.data.source.wavelength = 6  # A
462        self.data.detector[index].pixel_size.x = 5  # mm
463        self.data.detector[index].pixel_size.y = 5  # mm
464        self.data.detector[index].beam_center.x = self.qmax_x
465        self.data.detector[index].beam_center.y = self.qmax_x
466        xmax = self.qmax_x
467        xmin = -xmax
468        ymax = self.qmax_x
469        ymin = -ymax
470        qstep = self.npts_x
471
472        x = numpy.linspace(start=xmin, stop=xmax, num=qstep, endpoint=True)
473        y = numpy.linspace(start=ymin, stop=ymax, num=qstep, endpoint=True)
474        # use data info instead
475        new_x = numpy.tile(x, (len(y), 1))
476        new_y = numpy.tile(y, (len(x), 1))
477        new_y = new_y.swapaxes(0, 1)
478        # all data require now in 1d array
479        qx_data = new_x.flatten()
480        qy_data = new_y.flatten()
481        q_data = numpy.sqrt(qx_data * qx_data + qy_data * qy_data)
482        # set all True (standing for unmasked) as default
483        mask = numpy.ones(len(qx_data), dtype=bool)
484        self.data.source = Source()
485        self.data.data = numpy.ones(len(mask))
486        self.data.err_data = numpy.ones(len(mask))
487        self.data.qx_data = qx_data
488        self.data.qy_data = qy_data
489        self.data.q_data = q_data
490        self.data.mask = mask
491        # store x and y bin centers in q space
492        self.data.x_bins = x
493        self.data.y_bins = y
494        # max and min taking account of the bin sizes
495        self.data.xmin = xmin
496        self.data.xmax = xmax
497        self.data.ymin = ymin
498        self.data.ymax = ymax
499
500    def _create_default_sld_data(self):
501        """
502        Copied from previous version
503        Making default sld-data
504        """
505        sld_n_default = 6.97e-06  # what is this number??
506        omfdata = sas_gen.OMFData()
507        omf2sld = sas_gen.OMF2SLD()
508        omf2sld.set_data(omfdata, self.default_shape)
509        self.sld_data = omf2sld.output
510        self.sld_data.is_data = False
511        self.sld_data.filename = "Default SLD Profile"
512        self.sld_data.set_sldn(sld_n_default)
513
514    def _create_default_1d_data(self):
515        """
516        Copied from previous version
517        Create 1D data by default
518        :warning: This data is never plotted.
519                    residuals.x = data_copy.x[index]
520            residuals.dy = numpy.ones(len(residuals.y))
521            residuals.dx = None
522            residuals.dxl = None
523            residuals.dxw = None
524        """
525        self.qmax_x = float(self.txtQxMax.text())
526        self.npts_x = int(self.txtNoQBins.text())
527        #  Default values
528        xmax = self.qmax_x
529        xmin = self.qmax_x * _Q1D_MIN
530        qstep = self.npts_x
531        x = numpy.linspace(start=xmin, stop=xmax, num=qstep, endpoint=True)
532        # store x and y bin centers in q space
533        y = numpy.ones(len(x))
534        dy = numpy.zeros(len(x))
535        dx = numpy.zeros(len(x))
536        self.data = Data1D(x=x, y=y)
537        self.data.dx = dx
538        self.data.dy = dy
539
540    def onCompute(self):
541        """
542        Copied from previous version
543        Execute the computation of I(qx, qy)
544        """
545        # Set default data when nothing loaded yet
546        if self.sld_data is None:
547            self._create_default_sld_data()
548        try:
549            self.model.set_sld_data(self.sld_data)
550            self.write_new_values_from_gui()
551            if self.is_avg or self.is_avg is None:
552                self._create_default_1d_data()
553                i_out = numpy.zeros(len(self.data.y))
554                inputs = [self.data.x, [], i_out]
555            else:
556                self._create_default_2d_data()
557                i_out = numpy.zeros(len(self.data.data))
558                inputs = [self.data.qx_data, self.data.qy_data, i_out]
559            logging.info("Computation is in progress...")
560            self.cmdCompute.setText('Wait...')
561            self.cmdCompute.setEnabled(False)
562            d = threads.deferToThread(self.complete, inputs, self._update)
563            # Add deferred callback for call return
564            #d.addCallback(self.plot_1_2d)
565            d.addCallback(self.calculateComplete)
566            d.addErrback(self.calculateFailed)
567        except:
568            log_msg = "{}. stop".format(sys.exc_info()[1])
569            logging.info(log_msg)
570        return
571
572    def _update(self, value):
573        """
574        Copied from previous version
575        """
576        pass
577
578    def calculateFailed(self, reason):
579        """
580        """
581        print("Calculate Failed with:\n", reason)
582        pass
583
584    def calculateComplete(self, d):
585        """
586        Notify the main thread
587        """
588        self.calculationFinishedSignal.emit()
589
590    def complete(self, input, update=None):
591        """
592        Gen compute complete function
593        :Param input: input list [qx_data, qy_data, i_out]
594        """
595        out = numpy.empty(0)
596        for ind in range(len(input[0])):
597            if self.is_avg:
598                if ind % 1 == 0 and update is not None:
599                    # update()
600                    percentage = int(100.0 * float(ind) / len(input[0]))
601                    update(percentage)
602                    time.sleep(0.001)  # 0.1
603                inputi = [input[0][ind:ind + 1], [], input[2][ind:ind + 1]]
604                outi = self.model.run(inputi)
605                out = numpy.append(out, outi)
606            else:
607                if ind % 50 == 0 and update is not None:
608                    percentage = int(100.0 * float(ind) / len(input[0]))
609                    update(percentage)
610                    time.sleep(0.001)
611                inputi = [input[0][ind:ind + 1], input[1][ind:ind + 1],
612                          input[2][ind:ind + 1]]
613                outi = self.model.runXY(inputi)
614                out = numpy.append(out, outi)
615        self.data_to_plot = out
616        logging.info('Gen computation completed.')
617        self.cmdCompute.setText('Compute')
618        self.cmdCompute.setEnabled(True)
619        return
620
621    def onSaveFile(self):
622        """Save data as .sld file"""
623        path = os.path.dirname(str(self.datafile))
624        default_name = os.path.join(path, 'sld_file')
625        kwargs = {
626            'parent': self,
627            'directory': default_name,
628            'filter': 'SLD file (*.sld)',
629            'options': QtWidgets.QFileDialog.DontUseNativeDialog}
630        # Query user for filename.
631        filename_tuple = QtWidgets.QFileDialog.getSaveFileName(**kwargs)
632        filename = filename_tuple[0]
633        if filename:
634            try:
635                _, extension = os.path.splitext(filename)
636                if not extension:
637                    filename = '.'.join((filename, 'sld'))
638                sas_gen.SLDReader().write(filename, self.sld_data)
639            except:
640                raise
641
642    def plot3d(self, has_arrow=False):
643        """ Generate 3D plot in real space with or without arrows """
644        self.write_new_values_from_gui()
645        graph_title = " Graph {}: {} 3D SLD Profile".format(self.graph_num,
646                                                            self.file_name)
647        if has_arrow:
648            graph_title += ' - Magnetic Vector as Arrow'
649
650        plot3D = Plotter3D(self, graph_title)
651        plot3D.plot(self.sld_data, has_arrow=has_arrow)
652        plot3D.show()
653        self.graph_num += 1
654
655    def plot_1_2d(self):
656        """ Generate 1D or 2D plot, called in Compute"""
657        if self.is_avg or self.is_avg is None:
658            data = Data1D(x=self.data.x, y=self.data_to_plot)
659            data.title = "GenSAS {}  #{} 1D".format(self.file_name,
660                                                    int(self.graph_num))
661            data.xaxis('\\rm{Q_{x}}', '\AA^{-1}')
662            data.yaxis('\\rm{Intensity}', 'cm^{-1}')
663            plot1D = Plotter(self, quickplot=True)
664            plot1D.plot(data)
665            plot1D.show()
666            self.graph_num += 1
667            # TODO
668            print('TRANSFER OF DATA TO MAIN PANEL TO BE IMPLEMENTED')
669            return plot1D
670        else:
671            numpy.nan_to_num(self.data_to_plot)
672            data = Data2D(image=self.data_to_plot,
673                          qx_data=self.data.qx_data,
674                          qy_data=self.data.qy_data,
675                          q_data=self.data.q_data,
676                          xmin=self.data.xmin, xmax=self.data.ymax,
677                          ymin=self.data.ymin, ymax=self.data.ymax,
678                          err_image=self.data.err_data)
679            data.title = "GenSAS {}  #{} 2D".format(self.file_name,
680                                                    int(self.graph_num))
681            plot2D = Plotter2D(self, quickplot=True)
682            plot2D.plot(data)
683            plot2D.show()
684            self.graph_num += 1
685            # TODO
686            print('TRANSFER OF DATA TO MAIN PANEL TO BE IMPLEMENTED')
687            return plot2D
688
689
690class Plotter3DWidget(PlotterBase):
691    """
692    3D Plot widget for use with a QDialog
693    """
694    def __init__(self, parent=None, manager=None):
695        super(Plotter3DWidget, self).__init__(parent,  manager=manager)
696
697    @property
698    def data(self):
699        return self._data
700
701    @data.setter
702    def data(self, data=None):
703        """ data setter """
704        self._data = data
705
706    def plot(self, data=None, has_arrow=False):
707        """
708        Plot 3D self._data
709        """
710        if not data:
711            return
712        self.data = data
713        #assert(self._data)
714        # Prepare and show the plot
715        self.showPlot(data=self.data, has_arrow=has_arrow)
716
717    def showPlot(self, data, has_arrow=False):
718        """
719        Render and show the current data
720        """
721        # If we don't have any data, skip.
722        if data is None:
723            return
724        # This import takes forever - place it here so the main UI starts faster
725        from mpl_toolkits.mplot3d import Axes3D
726        color_dic = {'H': 'blue', 'D': 'purple', 'N': 'orange',
727                     'O': 'red', 'C': 'green', 'P': 'cyan', 'Other': 'k'}
728        marker = ','
729        m_size = 2
730
731        pos_x = data.pos_x
732        pos_y = data.pos_y
733        pos_z = data.pos_z
734        sld_mx = data.sld_mx
735        sld_my = data.sld_my
736        sld_mz = data.sld_mz
737        pix_symbol = data.pix_symbol
738        sld_tot = numpy.fabs(sld_mx) + numpy.fabs(sld_my) + \
739                  numpy.fabs(sld_mz) + numpy.fabs(data.sld_n)
740        is_nonzero = sld_tot > 0.0
741        is_zero = sld_tot == 0.0
742
743        if data.pix_type == 'atom':
744            marker = 'o'
745            m_size = 3.5
746
747        self.figure.clear()
748        self.figure.subplots_adjust(left=0.1, right=.8, bottom=.1)
749        ax = Axes3D(self.figure)
750        ax.set_xlabel('x ($\A{}$)'.format(data.pos_unit))
751        ax.set_ylabel('z ($\A{}$)'.format(data.pos_unit))
752        ax.set_zlabel('y ($\A{}$)'.format(data.pos_unit))
753
754        # I. Plot null points
755        if is_zero.any():
756            im = ax.plot(pos_x[is_zero], pos_z[is_zero], pos_y[is_zero],
757                           marker, c="y", alpha=0.5, markeredgecolor='y',
758                           markersize=m_size)
759            pos_x = pos_x[is_nonzero]
760            pos_y = pos_y[is_nonzero]
761            pos_z = pos_z[is_nonzero]
762            sld_mx = sld_mx[is_nonzero]
763            sld_my = sld_my[is_nonzero]
764            sld_mz = sld_mz[is_nonzero]
765            pix_symbol = data.pix_symbol[is_nonzero]
766        # II. Plot selective points in color
767        other_color = numpy.ones(len(pix_symbol), dtype='bool')
768        for key in list(color_dic.keys()):
769            chosen_color = pix_symbol == key
770            if numpy.any(chosen_color):
771                other_color = other_color & (chosen_color!=True)
772                color = color_dic[key]
773                im = ax.plot(pos_x[chosen_color], pos_z[chosen_color],
774                         pos_y[chosen_color], marker, c=color, alpha=0.5,
775                         markeredgecolor=color, markersize=m_size, label=key)
776        # III. Plot All others
777        if numpy.any(other_color):
778            a_name = ''
779            if data.pix_type == 'atom':
780                # Get atom names not in the list
781                a_names = [symb for symb in pix_symbol \
782                           if symb not in list(color_dic.keys())]
783                a_name = a_names[0]
784                for name in a_names:
785                    new_name = ", " + name
786                    if a_name.count(name) == 0:
787                        a_name += new_name
788            # plot in black
789            im = ax.plot(pos_x[other_color], pos_z[other_color],
790                         pos_y[other_color], marker, c="k", alpha=0.5,
791                         markeredgecolor="k", markersize=m_size, label=a_name)
792        if data.pix_type == 'atom':
793            ax.legend(loc='upper left', prop={'size': 10})
794        # IV. Draws atomic bond with grey lines if any
795        if data.has_conect:
796            for ind in range(len(data.line_x)):
797                im = ax.plot(data.line_x[ind], data.line_z[ind],
798                             data.line_y[ind], '-', lw=0.6, c="grey",
799                             alpha=0.3)
800        # V. Draws magnetic vectors
801        if has_arrow and len(pos_x) > 0:
802            def _draw_arrow(input=None, update=None):
803                # import moved here for performance reasons
804                from sas.qtgui.Plotting.Arrow3D import Arrow3D
805                """
806                draw magnetic vectors w/arrow
807                """
808                max_mx = max(numpy.fabs(sld_mx))
809                max_my = max(numpy.fabs(sld_my))
810                max_mz = max(numpy.fabs(sld_mz))
811                max_m = max(max_mx, max_my, max_mz)
812                try:
813                    max_step = max(data.xstepsize, data.ystepsize, data.zstepsize)
814                except:
815                    max_step = 0
816                if max_step <= 0:
817                    max_step = 5
818                try:
819                    if max_m != 0:
820                        unit_x2 = sld_mx / max_m
821                        unit_y2 = sld_my / max_m
822                        unit_z2 = sld_mz / max_m
823                        # 0.8 is for avoiding the color becomes white=(1,1,1))
824                        color_x = numpy.fabs(unit_x2 * 0.8)
825                        color_y = numpy.fabs(unit_y2 * 0.8)
826                        color_z = numpy.fabs(unit_z2 * 0.8)
827                        x2 = pos_x + unit_x2 * max_step
828                        y2 = pos_y + unit_y2 * max_step
829                        z2 = pos_z + unit_z2 * max_step
830                        x_arrow = numpy.column_stack((pos_x, x2))
831                        y_arrow = numpy.column_stack((pos_y, y2))
832                        z_arrow = numpy.column_stack((pos_z, z2))
833                        colors = numpy.column_stack((color_x, color_y, color_z))
834                        arrows = Arrow3D(self.figure, x_arrow, z_arrow, y_arrow,
835                                        colors, mutation_scale=10, lw=1,
836                                        arrowstyle="->", alpha=0.5)
837                        ax.add_artist(arrows)
838                except:
839                    pass
840                log_msg = "Arrow Drawing completed.\n"
841                logging.info(log_msg)
842            log_msg = "Arrow Drawing is in progress..."
843            logging.info(log_msg)
844
845            # Defer the drawing of arrows to another thread
846            d = threads.deferToThread(_draw_arrow, ax)
847
848        self.figure.canvas.resizing = False
849        self.figure.canvas.draw()
850
851    def createContextMenu(self):
852        """
853        Define common context menu and associated actions for the MPL widget
854        """
855        return
856
857    def createContextMenuQuick(self):
858        """
859        Define context menu and associated actions for the quickplot MPL widget
860        """
861        return
862
863
864class Plotter3D(QtWidgets.QDialog, Plotter3DWidget):
865    def __init__(self, parent=None, graph_title=''):
866        self.graph_title = graph_title
867        QtWidgets.QDialog.__init__(self)
868        Plotter3DWidget.__init__(self, manager=parent)
869        self.setWindowTitle(self.graph_title)
870
Note: See TracBrowser for help on using the repository browser.