source: sasview/src/sas/qtgui/Calculators/GenericScatteringCalculator.py @ 9909967

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalc
Last change on this file since 9909967 was 4d1eff2, checked in by Piotr Rozyczko <rozyczko@…>, 7 years ago

Fixed misbehaving calculator.

  • Property mode set to 100644
File size: 33.3 KB
Line 
1import sys
2import os
3import numpy
4import logging
5import time
6
7from PyQt4 import QtGui
8from PyQt4 import QtCore
9from twisted.internet import threads
10from mpl_toolkits.mplot3d import Axes3D
11
12import sas.qtgui.Utilities.GuiUtils as GuiUtils
13
14from sas.qtgui.Utilities.GenericReader import GenReader
15
16from sas.sascalc.dataloader.data_info import Detector
17from sas.sascalc.dataloader.data_info import Source
18from sas.sascalc.calculator import sas_gen
19
20from sas.qtgui.Plotting.Arrow3D import Arrow3D
21from sas.qtgui.Plotting.PlotterBase import PlotterBase
22from sas.qtgui.Plotting.Plotter2D import Plotter2D
23from sas.qtgui.Plotting.Plotter import Plotter
24from sas.qtgui.Plotting.PlotterData import Data1D
25from sas.qtgui.Plotting.PlotterData import Data2D
26
27# Local UI
28from UI.GenericScatteringCalculator import Ui_GenericScatteringCalculator
29
30_Q1D_MIN = 0.001
31
32
33class GenericScatteringCalculator(QtGui.QDialog, Ui_GenericScatteringCalculator):
34
35    trigger_plot_3d = QtCore.pyqtSignal()
36
37    def __init__(self, parent=None):
38        super(GenericScatteringCalculator, self).__init__()
39        self.setupUi(self)
40        self.manager = parent
41        self.communicator = self.manager.communicator()
42        self.model = sas_gen.GenSAS()
43        self.omf_reader = sas_gen.OMFReader()
44        self.sld_reader = sas_gen.SLDReader()
45        self.pdb_reader = sas_gen.PDBReader()
46        self.reader = None
47        self.sld_data = None
48
49        self.parameters = []
50        self.data = None
51        self.datafile = None
52        self.file_name = ''
53        self.ext = None
54        self.default_shape = str(self.cbShape.currentText())
55        self.is_avg = False
56        self.data_to_plot = None
57        self.graph_num = 1  # index for name of graph
58
59        # combox box
60        self.cbOptionsCalc.setVisible(False)
61
62        # push buttons
63        self.cmdClose.clicked.connect(self.accept)
64        self.cmdHelp.clicked.connect(self.onHelp)
65
66        self.cmdLoad.clicked.connect(self.loadFile)
67        self.cmdCompute.clicked.connect(self.onCompute)
68        self.cmdReset.clicked.connect(self.onReset)
69        self.cmdSave.clicked.connect(self.onSaveFile)
70
71        self.cmdDraw.clicked.connect(lambda: self.plot3d(has_arrow=True))
72        self.cmdDrawpoints.clicked.connect(lambda: self.plot3d(has_arrow=False))
73
74        # validators
75        # scale, volume and background must be positive
76        validat_regex_pos = QtCore.QRegExp('^[+]?([.]\d+|\d+([.]\d+)?)$')
77        self.txtScale.setValidator(QtGui.QRegExpValidator(validat_regex_pos,
78                                                          self.txtScale))
79        self.txtBackground.setValidator(QtGui.QRegExpValidator(
80            validat_regex_pos, self.txtBackground))
81        self.txtTotalVolume.setValidator(QtGui.QRegExpValidator(
82            validat_regex_pos, self.txtTotalVolume))
83
84        # fraction of spin up between 0 and 1
85        validat_regexbetween0_1 = QtCore.QRegExp('^(0(\.\d*)*|1(\.0+)?)$')
86        self.txtUpFracIn.setValidator(
87            QtGui.QRegExpValidator(validat_regexbetween0_1, self.txtUpFracIn))
88        self.txtUpFracOut.setValidator(
89            QtGui.QRegExpValidator(validat_regexbetween0_1, self.txtUpFracOut))
90
91        # 0 < Qmax <= 1000
92        validat_regex_q = QtCore.QRegExp('^1000$|^[+]?(\d{1,3}([.]\d+)?)$')
93        self.txtQxMax.setValidator(QtGui.QRegExpValidator(validat_regex_q,
94                                                          self.txtQxMax))
95        self.txtQxMax.textChanged.connect(self.check_value)
96
97        # 2 <= Qbin <= 1000
98        self.txtNoQBins.setValidator(QtGui.QRegExpValidator(validat_regex_q,
99                                                            self.txtNoQBins))
100        self.txtNoQBins.textChanged.connect(self.check_value)
101
102        # plots - 3D in real space
103        self.trigger_plot_3d.connect(lambda: self.plot3d(has_arrow=False))
104
105        # TODO the option Ellipsoid has not been implemented
106        self.cbShape.currentIndexChanged.connect(self.selectedshapechange)
107
108        # New font to display angstrom symbol
109        new_font = 'font-family: -apple-system, "Helvetica Neue", "Ubuntu";'
110        self.lblUnitSolventSLD.setStyleSheet(new_font)
111        self.lblUnitVolume.setStyleSheet(new_font)
112        self.lbl5.setStyleSheet(new_font)
113        self.lblUnitMx.setStyleSheet(new_font)
114        self.lblUnitMy.setStyleSheet(new_font)
115        self.lblUnitMz.setStyleSheet(new_font)
116        self.lblUnitNucl.setStyleSheet(new_font)
117        self.lblUnitx.setStyleSheet(new_font)
118        self.lblUnity.setStyleSheet(new_font)
119        self.lblUnitz.setStyleSheet(new_font)
120
121    def selectedshapechange(self):
122        """
123        TODO Temporary solution to display information about option 'Ellipsoid'
124        """
125        print "The option Ellipsoid has not been implemented yet."
126        self.communicator.statusBarUpdateSignal.emit(
127            "The option Ellipsoid has not been implemented yet.")
128
129    def loadFile(self):
130        """
131        Open menu to choose the datafile to load
132        Only extensions .SLD, .PDB, .OMF, .sld, .pdb, .omf
133        """
134        try:
135            self.datafile = QtGui.QFileDialog.getOpenFileName(
136                self, "Choose a file", "", "All Gen files (*.OMF *.omf) ;;"
137                                          "SLD files (*.SLD *.sld);;PDB files (*.pdb *.PDB);; "
138                                          "OMF files (*.OMF *.omf);; "
139                                          "All files (*.*)")
140            if self.datafile:
141                self.default_shape = str(self.cbShape.currentText())
142                self.file_name = os.path.basename(str(self.datafile))
143                self.ext = os.path.splitext(str(self.datafile))[1]
144                if self.ext in self.omf_reader.ext:
145                    loader = self.omf_reader
146                elif self.ext in self.sld_reader.ext:
147                    loader = self.sld_reader
148                elif self.ext in self.pdb_reader.ext:
149                    loader = self.pdb_reader
150                else:
151                    loader = None
152                # disable some entries depending on type of loaded file
153                # (according to documentation)
154                if self.ext.lower() in ['.sld', '.omf', '.pdb']:
155                    self.txtUpFracIn.setEnabled(False)
156                    self.txtUpFracOut.setEnabled(False)
157                    self.txtUpTheta.setEnabled(False)
158
159                if self.reader is not None and self.reader.isrunning():
160                    self.reader.stop()
161                self.cmdLoad.setEnabled(False)
162                self.cmdLoad.setText('Loading...')
163                self.communicator.statusBarUpdateSignal.emit(
164                    "Loading File {}".format(os.path.basename(
165                        str(self.datafile))))
166                self.reader = GenReader(path=str(self.datafile), loader=loader,
167                                        completefn=self.complete_loading,
168                                        updatefn=self.load_update)
169                self.reader.queue()
170        except (RuntimeError, IOError):
171            log_msg = "Generic SAS Calculator: %s" % sys.exc_value
172            logging.info(log_msg)
173            raise
174        return
175
176    def load_update(self):
177        """ Legacy function used in GenRead """
178        if self.reader.isrunning():
179            status_type = "progress"
180        else:
181            status_type = "stop"
182        logging.info(status_type)
183
184    def complete_loading(self, data=None):
185        """ Function used in GenRead"""
186        self.cbShape.setEnabled(False)
187        try:
188            is_pdbdata = False
189            self.txtData.setText(os.path.basename(str(self.datafile)))
190            self.is_avg = False
191            if self.ext in self.omf_reader.ext:
192                gen = sas_gen.OMF2SLD()
193                gen.set_data(data)
194                self.sld_data = gen.get_magsld()
195                self.check_units()
196            elif self.ext in self.sld_reader.ext:
197                self.sld_data = data
198            elif self.ext in self.pdb_reader.ext:
199                self.sld_data = data
200                is_pdbdata = True
201            # Display combobox of orientation only for pdb data
202            self.cbOptionsCalc.setVisible(is_pdbdata)
203            self.update_gui()
204        except IOError:
205            log_msg = "Loading Error: " \
206                      "This file format is not supported for GenSAS."
207            logging.info(log_msg)
208            raise
209        except ValueError:
210            log_msg = "Could not find any data"
211            logging.info(log_msg)
212            raise
213        logging.info("Load Complete")
214        self.cmdLoad.setEnabled(True)
215        self.cmdLoad.setText('Load')
216        self.trigger_plot_3d.emit()
217
218    def check_units(self):
219        """
220        Check if the units from the OMF file correspond to the default ones
221        displayed on the interface.
222        If not, modify the GUI with the correct unit
223        """
224        #  TODO: adopt the convention of font and symbol for the updated values
225        if sas_gen.OMFData().valueunit != 'A^(-2)':
226            value_unit = sas_gen.OMFData().valueunit
227            self.lbl_unitMx.setText(value_unit)
228            self.lbl_unitMy.setText(value_unit)
229            self.lbl_unitMz.setText(value_unit)
230            self.lbl_unitNucl.setText(value_unit)
231        if sas_gen.OMFData().meshunit != 'A':
232            mesh_unit = sas_gen.OMFData().meshunit
233            self.lbl_unitx.setText(mesh_unit)
234            self.lbl_unity.setText(mesh_unit)
235            self.lbl_unitz.setText(mesh_unit)
236            self.lbl_unitVolume.setText(mesh_unit+"^3")
237
238    def check_value(self):
239        """Check range of text edits for QMax and Number of Qbins """
240        text_edit = self.sender()
241        text_edit.setStyleSheet(
242            QtCore.QString.fromUtf8('background-color: rgb(255, 255, 255);'))
243        if text_edit.text():
244            value = float(str(text_edit.text()))
245            if text_edit == self.txtQxMax:
246                if value <= 0 or value > 1000:
247                    text_edit.setStyleSheet(QtCore.QString.fromUtf8(
248                        'background-color: rgb(255, 182, 193);'))
249                else:
250                    text_edit.setStyleSheet(QtCore.QString.fromUtf8(
251                        'background-color: rgb(255, 255, 255);'))
252            elif text_edit == self.txtNoQBins:
253                if value < 2 or value > 1000:
254                    self.txtNoQBins.setStyleSheet(QtCore.QString.fromUtf8(
255                        'background-color: rgb(255, 182, 193);'))
256                else:
257                    self.txtNoQBins.setStyleSheet(QtCore.QString.fromUtf8(
258                        'background-color: rgb(255, 255, 255);'))
259
260    def update_gui(self):
261        """ Update the interface with values from loaded data """
262        self.model.set_is_avg(self.is_avg)
263        self.model.set_sld_data(self.sld_data)
264        self.model.params['total_volume'] = len(self.sld_data.sld_n)*self.sld_data.vol_pix[0]
265
266        # add condition for activation of save button
267        self.cmdSave.setEnabled(True)
268
269        # activation of 3D plots' buttons (with and without arrows)
270        self.cmdDraw.setEnabled(self.sld_data is not None)
271        self.cmdDrawpoints.setEnabled(self.sld_data is not None)
272
273        self.txtScale.setText(str(self.model.params['scale']))
274        self.txtBackground.setText(str(self.model.params['background']))
275        self.txtSolventSLD.setText(str(self.model.params['solvent_SLD']))
276
277        # Volume to write to interface: npts x volume of first pixel
278        self.txtTotalVolume.setText(str(len(self.sld_data.sld_n)*self.sld_data.vol_pix[0]))
279
280        self.txtUpFracIn.setText(str(self.model.params['Up_frac_in']))
281        self.txtUpFracOut.setText(str(self.model.params['Up_frac_out']))
282        self.txtUpTheta.setText(str(self.model.params['Up_theta']))
283
284        self.txtNoPixels.setText(str(len(self.sld_data.sld_n)))
285        self.txtNoPixels.setEnabled(False)
286
287        list_parameters = ['sld_mx', 'sld_my', 'sld_mz', 'sld_n', 'xnodes',
288                           'ynodes', 'znodes', 'xstepsize', 'ystepsize',
289                           'zstepsize']
290        list_gui_button = [self.txtMx, self.txtMy, self.txtMz, self.txtNucl,
291                           self.txtXnodes, self.txtYnodes, self.txtZnodes,
292                           self.txtXstepsize, self.txtYstepsize,
293                           self.txtZstepsize]
294
295        # Fill right hand side of GUI
296        for indx, item in enumerate(list_parameters):
297            if getattr(self.sld_data, item) is None:
298                list_gui_button[indx].setText('NaN')
299            else:
300                value = getattr(self.sld_data, item)
301                if isinstance(value, numpy.ndarray):
302                    item_for_gui = str(GuiUtils.formatNumber(numpy.average(value), True))
303                else:
304                    item_for_gui = str(GuiUtils.formatNumber(value, True))
305                list_gui_button[indx].setText(item_for_gui)
306
307        # Enable / disable editing of right hand side of GUI
308        for indx, item in enumerate(list_parameters):
309            if indx < 4:
310                # this condition only applies to Mx,y,z and Nucl
311                value = getattr(self.sld_data, item)
312                enable = self.sld_data.pix_type == 'pixel' \
313                         and numpy.min(value) == numpy.max(value)
314            else:
315                enable = not self.sld_data.is_data
316            list_gui_button[indx].setEnabled(enable)
317
318    def write_new_values_from_gui(self):
319        """
320        update parameters using modified inputs from GUI
321        used before computing
322        """
323        if self.txtScale.isModified():
324            self.model.params['scale'] = float(self.txtScale.text())
325
326        if self.txtBackground.isModified():
327            self.model.params['background'] = float(self.txtBackground.text())
328
329        if self.txtSolventSLD.isModified():
330            self.model.params['solvent_SLD'] = float(self.txtSolventSLD.text())
331
332        # Different condition for total volume to get correct volume after
333        # applying set_sld_data in compute
334        if self.txtTotalVolume.isModified() \
335                or self.model.params['total_volume'] != float(self.txtTotalVolume.text()):
336            self.model.params['total_volume'] = float(self.txtTotalVolume.text())
337
338        if self.txtUpFracIn.isModified():
339            self.model.params['Up_frac_in'] = float(self.txtUpFracIn.text())
340
341        if self.txtUpFracOut.isModified():
342            self.model.params['Up_frac_out'] = float(self.txtUpFracOut.text())
343
344        if self.txtUpTheta.isModified():
345            self.model.params['Up_theta'] = float(self.txtUpTheta.text())
346
347        if self.txtMx.isModified():
348            self.sld_data.sld_mx = float(self.txtMx.text())*\
349                                   numpy.ones(len(self.sld_data.sld_mx))
350
351        if self.txtMy.isModified():
352            self.sld_data.sld_my = float(self.txtMy.text())*\
353                                   numpy.ones(len(self.sld_data.sld_my))
354
355        if self.txtMz.isModified():
356            self.sld_data.sld_mz = float(self.txtMz.text())*\
357                                   numpy.ones(len(self.sld_data.sld_mz))
358
359        if self.txtNucl.isModified():
360            self.sld_data.sld_n = float(self.txtNucl.text())*\
361                                  numpy.ones(len(self.sld_data.sld_n))
362
363        if self.txtXnodes.isModified():
364            self.sld_data.xnodes = int(self.txtXnodes.text())
365
366        if self.txtYnodes.isModified():
367            self.sld_data.ynodes = int(self.txtYnodes.text())
368
369        if self.txtZnodes.isModified():
370            self.sld_data.znodes = int(self.txtZnodes.text())
371
372        if self.txtXstepsize.isModified():
373            self.sld_data.xstepsize = float(self.txtXstepsize.text())
374
375        if self.txtYstepsize.isModified():
376            self.sld_data.ystepsize = float(self.txtYstepsize.text())
377
378        if self.txtZstepsize.isModified():
379            self.sld_data.zstepsize = float(self.txtZstepsize.text())
380
381        if self.cbOptionsCalc.isVisible():
382            self.is_avg = (self.cbOptionsCalc.currentIndex() == 1)
383
384    def onHelp(self):
385        """
386        Bring up the Generic Scattering calculator Documentation whenever
387        the HELP button is clicked.
388        Calls Documentation Window with the path of the location within the
389        documentation tree (after /doc/ ....".
390        """
391        try:
392            location = GuiUtils.HELP_DIRECTORY_LOCATION + \
393                       "/user/sasgui/perspectives/calculator/sas_calculator_help.html"
394            self.manager._helpView.load(QtCore.QUrl(location))
395            self.manager._helpView.show()
396        except AttributeError:
397            # No manager defined - testing and standalone runs
398            pass
399
400    def onReset(self):
401        """ Reset the inputs of textEdit to default values """
402        try:
403            # reset values in textedits
404            self.txtUpFracIn.setText("1.0")
405            self.txtUpFracOut.setText("1.0")
406            self.txtUpTheta.setText("0.0")
407            self.txtBackground.setText("0.0")
408            self.txtScale.setText("1.0")
409            self.txtSolventSLD.setText("0.0")
410            self.txtTotalVolume.setText("216000.0")
411            self.txtNoQBins.setText("50")
412            self.txtQxMax.setText("0.3")
413            self.txtNoPixels.setText("1000")
414            self.txtMx.setText("0")
415            self.txtMy.setText("0")
416            self.txtMz.setText("0")
417            self.txtNucl.setText("6.97e-06")
418            self.txtXnodes.setText("10")
419            self.txtYnodes.setText("10")
420            self.txtZnodes.setText("10")
421            self.txtXstepsize.setText("6")
422            self.txtYstepsize.setText("6")
423            self.txtZstepsize.setText("6")
424            # reset Load button and textedit
425            self.txtData.setText('Default SLD Profile')
426            self.cmdLoad.setEnabled(True)
427            self.cmdLoad.setText('Load')
428            # reset option for calculation
429            self.cbOptionsCalc.setCurrentIndex(0)
430            self.cbOptionsCalc.setVisible(False)
431            # reset shape button
432            self.cbShape.setCurrentIndex(0)
433            self.cbShape.setEnabled(True)
434            # reset compute button
435            self.cmdCompute.setText('Compute')
436            self.cmdCompute.setEnabled(True)
437            # TODO reload default data set
438            self._create_default_sld_data()
439
440        finally:
441            pass
442
443    def _create_default_2d_data(self):
444        """
445        Copied from previous version
446        Create 2D data by default
447        :warning: This data is never plotted.
448        """
449        self.qmax_x = float(self.txtQxMax.text())
450        self.npts_x = int(self.txtNoQBins.text())
451        self.data = Data2D()
452        self.data.is_data = False
453        # # Default values
454        self.data.detector.append(Detector())
455        index = len(self.data.detector) - 1
456        self.data.detector[index].distance = 8000  # mm
457        self.data.source.wavelength = 6  # A
458        self.data.detector[index].pixel_size.x = 5  # mm
459        self.data.detector[index].pixel_size.y = 5  # mm
460        self.data.detector[index].beam_center.x = self.qmax_x
461        self.data.detector[index].beam_center.y = self.qmax_x
462        xmax = self.qmax_x
463        xmin = -xmax
464        ymax = self.qmax_x
465        ymin = -ymax
466        qstep = self.npts_x
467
468        x = numpy.linspace(start=xmin, stop=xmax, num=qstep, endpoint=True)
469        y = numpy.linspace(start=ymin, stop=ymax, num=qstep, endpoint=True)
470        # use data info instead
471        new_x = numpy.tile(x, (len(y), 1))
472        new_y = numpy.tile(y, (len(x), 1))
473        new_y = new_y.swapaxes(0, 1)
474        # all data require now in 1d array
475        qx_data = new_x.flatten()
476        qy_data = new_y.flatten()
477        q_data = numpy.sqrt(qx_data * qx_data + qy_data * qy_data)
478        # set all True (standing for unmasked) as default
479        mask = numpy.ones(len(qx_data), dtype=bool)
480        self.data.source = Source()
481        self.data.data = numpy.ones(len(mask))
482        self.data.err_data = numpy.ones(len(mask))
483        self.data.qx_data = qx_data
484        self.data.qy_data = qy_data
485        self.data.q_data = q_data
486        self.data.mask = mask
487        # store x and y bin centers in q space
488        self.data.x_bins = x
489        self.data.y_bins = y
490        # max and min taking account of the bin sizes
491        self.data.xmin = xmin
492        self.data.xmax = xmax
493        self.data.ymin = ymin
494        self.data.ymax = ymax
495
496    def _create_default_sld_data(self):
497        """
498        Copied from previous version
499        Making default sld-data
500        """
501        sld_n_default = 6.97e-06  # what is this number??
502        omfdata = sas_gen.OMFData()
503        omf2sld = sas_gen.OMF2SLD()
504        omf2sld.set_data(omfdata, self.default_shape)
505        self.sld_data = omf2sld.output
506        self.sld_data.is_data = False
507        self.sld_data.filename = "Default SLD Profile"
508        self.sld_data.set_sldn(sld_n_default)
509
510    def _create_default_1d_data(self):
511        """
512        Copied from previous version
513        Create 1D data by default
514        :warning: This data is never plotted.
515                    residuals.x = data_copy.x[index]
516            residuals.dy = numpy.ones(len(residuals.y))
517            residuals.dx = None
518            residuals.dxl = None
519            residuals.dxw = None
520        """
521        self.qmax_x = float(self.txtQxMax.text())
522        self.npts_x = int(self.txtNoQBins.text())
523        #  Default values
524        xmax = self.qmax_x
525        xmin = self.qmax_x * _Q1D_MIN
526        qstep = self.npts_x
527        x = numpy.linspace(start=xmin, stop=xmax, num=qstep, endpoint=True)
528        # store x and y bin centers in q space
529        y = numpy.ones(len(x))
530        dy = numpy.zeros(len(x))
531        dx = numpy.zeros(len(x))
532        self.data = Data1D(x=x, y=y)
533        self.data.dx = dx
534        self.data.dy = dy
535
536    def onCompute(self):
537        """
538        Copied from previous version
539        Execute the computation of I(qx, qy)
540        """
541        # Set default data when nothing loaded yet
542        if self.sld_data is None:
543            self._create_default_sld_data()
544        try:
545            self.model.set_sld_data(self.sld_data)
546            self.write_new_values_from_gui()
547            if self.is_avg or self.is_avg is None:
548                self._create_default_1d_data()
549                i_out = numpy.zeros(len(self.data.y))
550                inputs = [self.data.x, [], i_out]
551            else:
552                self._create_default_2d_data()
553                i_out = numpy.zeros(len(self.data.data))
554                inputs = [self.data.qx_data, self.data.qy_data, i_out]
555            logging.info("Computation is in progress...")
556            self.cmdCompute.setText('Wait...')
557            self.cmdCompute.setEnabled(False)
558            d = threads.deferToThread(self.complete, inputs, self._update)
559            # Add deferred callback for call return
560            d.addCallback(self.plot_1_2d)
561        except:
562            log_msg = "{}. stop".format(sys.exc_value)
563            logging.info(log_msg)
564        return
565
566    def _update(self, value):
567        """
568        Copied from previous version
569        """
570        pass
571
572    def complete(self, input, update=None):
573        """
574        Gen compute complete function
575        :Param input: input list [qx_data, qy_data, i_out]
576        """
577        out = numpy.empty(0)
578        for ind in range(len(input[0])):
579            if self.is_avg:
580                if ind % 1 == 0 and update is not None:
581                    # update()
582                    percentage = int(100.0 * float(ind) / len(input[0]))
583                    update(percentage)
584                    time.sleep(0.001)  # 0.1
585                inputi = [input[0][ind:ind + 1], [], input[2][ind:ind + 1]]
586                outi = self.model.run(inputi)
587                out = numpy.append(out, outi)
588            else:
589                if ind % 50 == 0 and update is not None:
590                    percentage = int(100.0 * float(ind) / len(input[0]))
591                    update(percentage)
592                    time.sleep(0.001)
593                inputi = [input[0][ind:ind + 1], input[1][ind:ind + 1],
594                          input[2][ind:ind + 1]]
595                outi = self.model.runXY(inputi)
596                out = numpy.append(out, outi)
597        self.data_to_plot = out
598        logging.info('Gen computation completed.')
599        self.cmdCompute.setText('Compute')
600        self.cmdCompute.setEnabled(True)
601        return
602
603    def onSaveFile(self):
604        """Save data as .sld file"""
605        path = os.path.dirname(str(self.datafile))
606        default_name = os.path.join(path, 'sld_file')
607        kwargs = {
608            'parent': self,
609            'directory': default_name,
610            'filter': 'SLD file (*.sld)',
611            'options': QtGui.QFileDialog.DontUseNativeDialog}
612        # Query user for filename.
613        filename = str(QtGui.QFileDialog.getSaveFileName(**kwargs))
614        if filename:
615            try:
616                if os.path.splitext(filename)[1].lower() == '.sld':
617                    sas_gen.SLDReader().write(filename, self.sld_data)
618                else:
619                    sas_gen.SLDReader().write('.'.join((filename, 'sld')),
620                                              self.sld_data)
621            except:
622                raise
623
624    def plot3d(self, has_arrow=False):
625        """ Generate 3D plot in real space with or without arrows """
626        self.write_new_values_from_gui()
627        graph_title = " Graph {}: {} 3D SLD Profile".format(self.graph_num,
628                                                            self.file_name)
629        if has_arrow:
630            graph_title += ' - Magnetic Vector as Arrow'
631
632        plot3D = Plotter3D(self, graph_title)
633        plot3D.plot(self.sld_data, has_arrow=has_arrow)
634        plot3D.show()
635        self.graph_num += 1
636
637    def plot_1_2d(self, d):
638        """ Generate 1D or 2D plot, called in Compute"""
639        if self.is_avg or self.is_avg is None:
640            data = Data1D(x=self.data.x, y=self.data_to_plot)
641            data.title = "GenSAS {}  #{} 1D".format(self.file_name,
642                                                    int(self.graph_num))
643            data.xaxis('\\rm{Q_{x}}', '\AA^{-1}')
644            data.yaxis('\\rm{Intensity}', 'cm^{-1}')
645            plot1D = Plotter(self)
646            plot1D.plot(data)
647            plot1D.show()
648            self.graph_num += 1
649            # TODO
650            print 'TRANSFER OF DATA TO MAIN PANEL TO BE IMPLEMENTED'
651            return plot1D
652        else:
653            numpy.nan_to_num(self.data_to_plot)
654            data = Data2D(image=self.data_to_plot,
655                          qx_data=self.data.qx_data,
656                          qy_data=self.data.qy_data,
657                          q_data=self.data.q_data,
658                          xmin=self.data.xmin, xmax=self.data.ymax,
659                          ymin=self.data.ymin, ymax=self.data.ymax,
660                          err_image=self.data.err_data)
661            data.title = "GenSAS {}  #{} 2D".format(self.file_name,
662                                                    int(self.graph_num))
663            plot2D = Plotter2D(self)
664            plot2D.plot(data)
665            plot2D.show()
666            self.graph_num += 1
667            # TODO
668            print 'TRANSFER OF DATA TO MAIN PANEL TO BE IMPLEMENTED'
669            return plot2D
670
671
672class Plotter3DWidget(PlotterBase):
673    """
674    3D Plot widget for use with a QDialog
675    """
676    def __init__(self, parent=None, manager=None):
677        super(Plotter3DWidget, self).__init__(parent,  manager=manager)
678
679    @property
680    def data(self):
681        return self._data
682
683    @data.setter
684    def data(self, data=None):
685        """ data setter """
686        self._data = data
687
688    def plot(self, data=None, has_arrow=False):
689        """
690        Plot 3D self._data
691        """
692        if not data:
693            return
694        self.data = data
695        #assert(self._data)
696        # Prepare and show the plot
697        self.showPlot(data=self.data, has_arrow=has_arrow)
698
699    def showPlot(self, data, has_arrow=False):
700        """
701        Render and show the current data
702        """
703        # If we don't have any data, skip.
704        if data is None:
705            return
706        color_dic = {'H': 'blue', 'D': 'purple', 'N': 'orange',
707                     'O': 'red', 'C': 'green', 'P': 'cyan', 'Other': 'k'}
708        marker = ','
709        m_size = 2
710
711        pos_x = data.pos_x
712        pos_y = data.pos_y
713        pos_z = data.pos_z
714        sld_mx = data.sld_mx
715        sld_my = data.sld_my
716        sld_mz = data.sld_mz
717        pix_symbol = data.pix_symbol
718        sld_tot = numpy.fabs(sld_mx) + numpy.fabs(sld_my) + \
719                  numpy.fabs(sld_mz) + numpy.fabs(data.sld_n)
720        is_nonzero = sld_tot > 0.0
721        is_zero = sld_tot == 0.0
722
723        if data.pix_type == 'atom':
724            marker = 'o'
725            m_size = 3.5
726
727        self.figure.clear()
728        self.figure.subplots_adjust(left=0.1, right=.8, bottom=.1)
729        ax = Axes3D(self.figure)
730        ax.set_xlabel('x ($\A{}$)'.format(data.pos_unit))
731        ax.set_ylabel('z ($\A{}$)'.format(data.pos_unit))
732        ax.set_zlabel('y ($\A{}$)'.format(data.pos_unit))
733
734        # I. Plot null points
735        if is_zero.any():
736            im = ax.plot(pos_x[is_zero], pos_z[is_zero], pos_y[is_zero],
737                           marker, c="y", alpha=0.5, markeredgecolor='y',
738                           markersize=m_size)
739            pos_x = pos_x[is_nonzero]
740            pos_y = pos_y[is_nonzero]
741            pos_z = pos_z[is_nonzero]
742            sld_mx = sld_mx[is_nonzero]
743            sld_my = sld_my[is_nonzero]
744            sld_mz = sld_mz[is_nonzero]
745            pix_symbol = data.pix_symbol[is_nonzero]
746        # II. Plot selective points in color
747        other_color = numpy.ones(len(pix_symbol), dtype='bool')
748        for key in color_dic.keys():
749            chosen_color = pix_symbol == key
750            if numpy.any(chosen_color):
751                other_color = other_color & (chosen_color!=True)
752                color = color_dic[key]
753                im = ax.plot(pos_x[chosen_color], pos_z[chosen_color],
754                         pos_y[chosen_color], marker, c=color, alpha=0.5,
755                         markeredgecolor=color, markersize=m_size, label=key)
756        # III. Plot All others
757        if numpy.any(other_color):
758            a_name = ''
759            if data.pix_type == 'atom':
760                # Get atom names not in the list
761                a_names = [symb for symb in pix_symbol \
762                           if symb not in color_dic.keys()]
763                a_name = a_names[0]
764                for name in a_names:
765                    new_name = ", " + name
766                    if a_name.count(name) == 0:
767                        a_name += new_name
768            # plot in black
769            im = ax.plot(pos_x[other_color], pos_z[other_color],
770                         pos_y[other_color], marker, c="k", alpha=0.5,
771                         markeredgecolor="k", markersize=m_size, label=a_name)
772        if data.pix_type == 'atom':
773            ax.legend(loc='upper left', prop={'size': 10})
774        # IV. Draws atomic bond with grey lines if any
775        if data.has_conect:
776            for ind in range(len(data.line_x)):
777                im = ax.plot(data.line_x[ind], data.line_z[ind],
778                             data.line_y[ind], '-', lw=0.6, c="grey",
779                             alpha=0.3)
780        # V. Draws magnetic vectors
781        if has_arrow and len(pos_x) > 0:
782            def _draw_arrow(input=None, update=None):
783                """
784                draw magnetic vectors w/arrow
785                """
786                max_mx = max(numpy.fabs(sld_mx))
787                max_my = max(numpy.fabs(sld_my))
788                max_mz = max(numpy.fabs(sld_mz))
789                max_m = max(max_mx, max_my, max_mz)
790                try:
791                    max_step = max(data.xstepsize, data.ystepsize, data.zstepsize)
792                except:
793                    max_step = 0
794                if max_step <= 0:
795                    max_step = 5
796                try:
797                    if max_m != 0:
798                        unit_x2 = sld_mx / max_m
799                        unit_y2 = sld_my / max_m
800                        unit_z2 = sld_mz / max_m
801                        # 0.8 is for avoiding the color becomes white=(1,1,1))
802                        color_x = numpy.fabs(unit_x2 * 0.8)
803                        color_y = numpy.fabs(unit_y2 * 0.8)
804                        color_z = numpy.fabs(unit_z2 * 0.8)
805                        x2 = pos_x + unit_x2 * max_step
806                        y2 = pos_y + unit_y2 * max_step
807                        z2 = pos_z + unit_z2 * max_step
808                        x_arrow = numpy.column_stack((pos_x, x2))
809                        y_arrow = numpy.column_stack((pos_y, y2))
810                        z_arrow = numpy.column_stack((pos_z, z2))
811                        colors = numpy.column_stack((color_x, color_y, color_z))
812                        arrows = Arrow3D(self.figure, x_arrow, z_arrow, y_arrow,
813                                        colors, mutation_scale=10, lw=1,
814                                        arrowstyle="->", alpha=0.5)
815                        ax.add_artist(arrows)
816                except:
817                    pass
818                log_msg = "Arrow Drawing completed.\n"
819                logging.info(log_msg)
820            log_msg = "Arrow Drawing is in progress..."
821            logging.info(log_msg)
822
823            # Defer the drawing of arrows to another thread
824            d = threads.deferToThread(_draw_arrow, ax)
825
826        self.figure.canvas.resizing = False
827        self.figure.canvas.draw()
828
829
830class Plotter3D(QtGui.QDialog, Plotter3DWidget):
831    def __init__(self, parent=None, graph_title=''):
832        self.graph_title = graph_title
833        QtGui.QDialog.__init__(self)
834        Plotter3DWidget.__init__(self, manager=parent)
835        self.setWindowTitle(self.graph_title)
836
Note: See TracBrowser for help on using the repository browser.