source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingWidget.py @ 98b13f72

Branches: ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc
Last change on this file since 98b13f72 was 98b13f72, checked in by Piotr Rozyczko <rozyczko@…>, 7 years ago

More smearing functionality

  • Property mode set to 100755
File size: 41.0 KB
Line 
1import sys
2import json
3import os
4import numpy
5from collections import defaultdict
6from itertools import izip
7
8import logging
9import traceback
10from twisted.internet import threads
11
12from PyQt4 import QtGui
13from PyQt4 import QtCore
14
15from sasmodels import generate
16from sasmodels import modelinfo
17from sasmodels.sasview_model import load_standard_models
18from sas.sascalc.fit.BumpsFitting import BumpsFit as Fit
19from sas.sasgui.perspectives.fitting.fit_thread import FitThread
20
21from sas.sasgui.guiframe.CategoryInstaller import CategoryInstaller
22from sas.sasgui.guiframe.dataFitting import Data1D
23from sas.sasgui.guiframe.dataFitting import Data2D
24import sas.qtgui.Utilities.GuiUtils as GuiUtils
25from sas.sasgui.perspectives.fitting.model_thread import Calc1D
26from sas.sasgui.perspectives.fitting.model_thread import Calc2D
27from sas.sasgui.perspectives.fitting.utils import get_weight
28
29from UI.FittingWidgetUI import Ui_FittingWidgetUI
30from sas.qtgui.Perspectives.Fitting.FittingLogic import FittingLogic
31from sas.qtgui.Perspectives.Fitting import FittingUtilities
32from SmearingWidget import SmearingWidget
33
# Indices of the optional tabs in self.tabFitting.
# NOTE(review): a "Resolution" tab is inserted at index 2 at runtime;
# presumably these indices already account for it - confirm against the .ui.
TAB_POLY = 3
TAB_MAGNETISM = 4

# Placeholder / sentinel entries of the category and structure factor combos
CATEGORY_DEFAULT = "Choose category..."
CATEGORY_STRUCTURE = "Structure Factor"
STRUCTURE_DEFAULT = "None"

# Initial Q range and number of points; replaced by the data range
# once a dataset is loaded
QMIN_DEFAULT = 0.0005
QMAX_DEFAULT = 0.5
NPTS_DEFAULT = 50
42
43class FittingWidget(QtGui.QWidget, Ui_FittingWidgetUI):
44    """
45    Main widget for selecting form and structure factor models
46    """
47    def __init__(self, parent=None, data=None, id=1):
48
49        super(FittingWidget, self).__init__()
50
51        # Necessary globals
52        self.parent = parent
53        # SasModel is loaded
54        self.model_is_loaded = False
55        # Data[12]D passed and set
56        self.data_is_loaded = False
57        # Current SasModel in view
58        self.kernel_module = None
59        # Current SasModel view dimension
60        self.is2D = False
61        # Current SasModel is multishell
62        self.model_has_shells = False
63        # Utility variable to enable unselectable option in category combobox
64        self._previous_category_index = 0
65        # Utility variable for multishell display
66        self._last_model_row = 0
67        # Dictionary of {model name: model class} for the current category
68        self.models = {}
69        # Parameters to fit
70        self.parameters_to_fit = None
71        # Weight radio box group
72        self.weightingGroup = QtGui.QButtonGroup()
73
74        # Which tab is this widget displayed in?
75        self.tab_id = id
76
77        # Which shell is being currently displayed?
78        self.current_shell_displayed = 0
79        self.has_error_column = False
80
81        # Range parameters
82        self.q_range_min = QMIN_DEFAULT
83        self.q_range_max = QMAX_DEFAULT
84        self.npts = NPTS_DEFAULT
85
86        # Main Data[12]D holder
87        self.logic = FittingLogic(data=data)
88
89        # Main GUI setup up
90        self.setupUi(self)
91        self.setWindowTitle("Fitting")
92        self.communicate = self.parent.communicate
93
94        # Smearing widget
95        layout = QtGui.QGridLayout()
96        self.smearing_widget = SmearingWidget(self)
97        layout.addWidget(self.smearing_widget) 
98        #self.tabFitting.removeTab(2)
99        self.tabFitting.insertTab(2, self.smearing_widget, "Resolution")
100
101        # Define bold font for use in various controls
102        self.boldFont=QtGui.QFont()
103        self.boldFont.setBold(True)
104
105        # Set data label
106        self.label.setFont(self.boldFont)
107        self.label.setText("No data loaded")
108        self.lblFilename.setText("")
109
110        # Set the main models
111        # We can't use a single model here, due to restrictions on flattening
112        # the model tree with subclassed QAbstractProxyModel...
113        self._model_model = QtGui.QStandardItemModel()
114        self._poly_model = QtGui.QStandardItemModel()
115        self._magnet_model = QtGui.QStandardItemModel()
116
117        # Param model displayed in param list
118        self.lstParams.setModel(self._model_model)
119        self.readCategoryInfo()
120        self.model_parameters = None
121        self.lstParams.setAlternatingRowColors(True)
122
123        # Poly model displayed in poly list
124        self.lstPoly.setModel(self._poly_model)
125        self.setPolyModel()
126        self.setTableProperties(self.lstPoly)
127
128        # Magnetism model displayed in magnetism list
129        self.lstMagnetic.setModel(self._magnet_model)
130        self.setMagneticModel()
131        self.setTableProperties(self.lstMagnetic)
132
133        # Defaults for the structure factors
134        self.setDefaultStructureCombo()
135
136        # Make structure factor and model CBs disabled
137        self.disableModelCombo()
138        self.disableStructureCombo()
139
140        # Generate the category list for display
141        category_list = sorted(self.master_category_dict.keys())
142        self.cbCategory.addItem(CATEGORY_DEFAULT)
143        self.cbCategory.addItems(category_list)
144        self.cbCategory.addItem(CATEGORY_STRUCTURE)
145        self.cbCategory.setCurrentIndex(0)
146
147        # Connect signals to controls
148        self.initializeSignals()
149
150        # Initial control state
151        self.initializeControls()
152
153        self._index = None
154        if data is not None:
155            self.data = data
156
157    @property
158    def data(self):
159        return self.logic.data
160
161    @data.setter
162    def data(self, value):
163        """ data setter """
164        assert isinstance(value, QtGui.QStandardItem)
165        # _index contains the QIndex with data
166        self._index = value
167
168        # Update logics with data items
169        self.logic.data = GuiUtils.dataFromItem(value)
170
171        # Overwrite data type descriptor
172        self.is2D = True if isinstance(self.logic.data, Data2D) else False
173
174        self.data_is_loaded = True
175
176        # Enable/disable UI components
177        self.setEnablementOnDataLoad()
178
179    def setEnablementOnDataLoad(self):
180        """
181        Enable/disable various UI elements based on data loaded
182        """
183        # Tag along functionality
184        self.label.setText("Data loaded from: ")
185        self.lblFilename.setText(self.logic.data.filename)
186        self.updateQRange()
187        self.cmdFit.setEnabled(True)
188        self.boxWeighting.setEnabled(True)
189        self.cmdMaskEdit.setEnabled(True)
190        # Switch off txtNpts related controls
191        self.txtNpts.setEnabled(False)
192        self.txtNptsFit.setEnabled(False)
193        self.chkLogData.setEnabled(False)
194        # Switch off Data2D control
195        self.chk2DView.setEnabled(False)
196        self.chk2DView.setVisible(False)
197        self.chkMagnetism.setEnabled(True)
198
199        # Weighting controls
200        if self.is2D:
201            if self.logic.data.err_data is None or\
202                    numpy.all(err == 1 for err in self.logic.data.err_data) or \
203                    not numpy.any(self.logic.data.err_data):
204                self.rbWeighting2.setEnabled(False)
205                self.rbWeighting1.setChecked(True)
206            else:
207                self.rbWeighting2.setEnabled(True)
208                self.rbWeighting2.setChecked(True)
209        else:
210            if self.logic.data.dy is None or\
211                    numpy.all(self.logic.data.dy == 1) or\
212                    not numpy.any(self.logic.data.dy):
213                self.rbWeighting2.setEnabled(False)
214                self.rbWeighting1.setChecked(True)
215            else:
216                self.rbWeighting2.setEnabled(True)
217                self.rbWeighting2.setChecked(True)
218
219        # Smearing tab
220        self.smearing_widget.updateSmearing(self.data)
221
222    def acceptsData(self):
223        """ Tells the caller this widget can accept new dataset """
224        return not self.data_is_loaded
225
226    def disableModelCombo(self):
227        """ Disable the combobox """
228        self.cbModel.setEnabled(False)
229        self.lblModel.setEnabled(False)
230
231    def enableModelCombo(self):
232        """ Enable the combobox """
233        self.cbModel.setEnabled(True)
234        self.lblModel.setEnabled(True)
235
236    def disableStructureCombo(self):
237        """ Disable the combobox """
238        self.cbStructureFactor.setEnabled(False)
239        self.lblStructure.setEnabled(False)
240
241    def enableStructureCombo(self):
242        """ Enable the combobox """
243        self.cbStructureFactor.setEnabled(True)
244        self.lblStructure.setEnabled(True)
245
246    def togglePoly(self, isChecked):
247        """ Enable/disable the polydispersity tab """
248        self.tabFitting.setTabEnabled(TAB_POLY, isChecked)
249
250    def toggleMagnetism(self, isChecked):
251        """ Enable/disable the magnetism tab """
252        self.tabFitting.setTabEnabled(TAB_MAGNETISM, isChecked)
253
254    def toggle2D(self, isChecked):
255        """ Enable/disable the controls dependent on 1D/2D data instance """
256        self.chkMagnetism.setEnabled(isChecked)
257        self.is2D = isChecked
258        # Reload the current model
259        if self.kernel_module:
260            self.onSelectModel()
261
262    def toggleLogData(self, isChecked):
263        """ Toggles between log and linear data sets """
264        pass
265
266    def initializeControls(self):
267        """
268        Set initial control enablement
269        """
270        self.cmdFit.setEnabled(False)
271        self.cmdPlot.setEnabled(True)
272        self.cmdComputePoints.setVisible(False) # probably redundant
273        self.chkPolydispersity.setEnabled(True)
274        self.chkPolydispersity.setCheckState(False)
275        self.chk2DView.setEnabled(True)
276        self.chk2DView.setCheckState(False)
277        self.chkMagnetism.setEnabled(False)
278        self.chkMagnetism.setCheckState(False)
279        # Tabs
280        self.tabFitting.setTabEnabled(TAB_POLY, False)
281        self.tabFitting.setTabEnabled(TAB_MAGNETISM, False)
282        self.lblChi2Value.setText("---")
283        # Group boxes
284        self.boxWeighting.setEnabled(False)
285        self.cmdMaskEdit.setEnabled(False)
286        # Button groups
287        self.weightingGroup.addButton(self.rbWeighting1)
288        self.weightingGroup.addButton(self.rbWeighting2)
289        self.weightingGroup.addButton(self.rbWeighting3)
290        self.weightingGroup.addButton(self.rbWeighting4)
291        # Smearing tab
292        self.smearing_widget.updateSmearing(self.data)
293
294    def initializeSignals(self):
295        """
296        Connect GUI element signals
297        """
298        # Comboboxes
299        self.cbStructureFactor.currentIndexChanged.connect(self.onSelectStructureFactor)
300        self.cbCategory.currentIndexChanged.connect(self.onSelectCategory)
301        self.cbModel.currentIndexChanged.connect(self.onSelectModel)
302        # Checkboxes
303        self.chk2DView.toggled.connect(self.toggle2D)
304        self.chkPolydispersity.toggled.connect(self.togglePoly)
305        self.chkMagnetism.toggled.connect(self.toggleMagnetism)
306        self.chkLogData.toggled.connect(self.toggleLogData)
307        # Buttons
308        self.cmdFit.clicked.connect(self.onFit)
309        self.cmdPlot.clicked.connect(self.onPlot)
310        self.cmdMaskEdit.clicked.connect(self.onMaskEdit)
311        self.cmdReset.clicked.connect(self.onReset)
312        # Line edits
313        self.txtNpts.editingFinished.connect(self.onNpts)
314        self.txtMinRange.editingFinished.connect(self.onMinRange)
315        self.txtMaxRange.editingFinished.connect(self.onMaxRange)
316        # Button groups
317        self.weightingGroup.buttonClicked.connect(self.onWeightingChoice)
318
319        # Respond to change in parameters from the UI
320        self._model_model.itemChanged.connect(self.updateParamsFromModel)
321        self._poly_model.itemChanged.connect(self.onPolyModelChange)
322        # TODO after the poly_model prototype accepted
323        #self._magnet_model.itemChanged.connect(self.onMagneticModelChange)
324
325    def onSelectModel(self):
326        """
327        Respond to select Model from list event
328        """
329        model = str(self.cbModel.currentText())
330
331        # Reset structure factor
332        self.cbStructureFactor.setCurrentIndex(0)
333
334        # Reset parameters to fit
335        self.parameters_to_fit = None
336        self.has_error_column = False
337
338        # SasModel -> QModel
339        self.SASModelToQModel(model)
340
341        if self.data_is_loaded:
342            self.calculateQGridForModel()
343        else:
344            # Create default datasets if no data passed
345            self.createDefaultDataset()
346
347    def onSelectStructureFactor(self):
348        """
349        Select Structure Factor from list
350        """
351        model = str(self.cbModel.currentText())
352        category = str(self.cbCategory.currentText())
353        structure = str(self.cbStructureFactor.currentText())
354        if category == CATEGORY_STRUCTURE:
355            model = None
356        self.SASModelToQModel(model, structure_factor=structure)
357
358    def onSelectCategory(self):
359        """
360        Select Category from list
361        """
362        category = str(self.cbCategory.currentText())
363        # Check if the user chose "Choose category entry"
364        if category == CATEGORY_DEFAULT:
365            # if the previous category was not the default, keep it.
366            # Otherwise, just return
367            if self._previous_category_index != 0:
368                # We need to block signals, or else state changes on perceived unchanged conditions
369                self.cbCategory.blockSignals(True)
370                self.cbCategory.setCurrentIndex(self._previous_category_index)
371                self.cbCategory.blockSignals(False)
372            return
373
374        if category == CATEGORY_STRUCTURE:
375            self.disableModelCombo()
376            self.enableStructureCombo()
377            self._model_model.clear()
378            return
379
380        # Safely clear and enable the model combo
381        self.cbModel.blockSignals(True)
382        self.cbModel.clear()
383        self.cbModel.blockSignals(False)
384        self.enableModelCombo()
385        self.disableStructureCombo()
386
387        self._previous_category_index = self.cbCategory.currentIndex()
388        # Retrieve the list of models
389        model_list = self.master_category_dict[category]
390        models = []
391        # Populate the models combobox
392        self.cbModel.addItems(sorted([model for (model, _) in model_list]))
393
394    def onWeightingChoice(self, button):
395        """
396        Update weighting in the fit state
397        """
398        button_id = button.group().checkedId()
399        button_id = abs(button_id + 2)
400        #self.fitPage.weighting = button_id
401        print button_id
402
403    def onPolyModelChange(self, item):
404        """
405        Callback method for updating the main model and sasmodel
406        parameters with the GUI values in the polydispersity view
407        """
408        model_column = item.column()
409        model_row = item.row()
410        name_index = self._poly_model.index(model_row, 0)
411        # Extract changed value. Assumes proper validation by QValidator/Delegate
412        # Checkbox in column 0
413        if model_column == 0:
414            value = item.checkState()
415        else:
416            try:
417                value = float(item.text())
418            except ValueError:
419                # Can't be converted properly, bring back the old value and exit
420                return
421
422        parameter_name = str(self._poly_model.data(name_index).toPyObject()) # "distribution of sld" etc.
423        if "Distribution of" in parameter_name:
424            parameter_name = parameter_name[16:]
425        property_name = str(self._poly_model.headerData(model_column, 1).toPyObject()) # Value, min, max, etc.
426        # print "%s(%s) => %d" % (parameter_name, property_name, value)
427
428        # Update the sasmodel
429        #self.kernel_module.params[parameter_name] = value
430
431        # Reload the main model - may not be required if no variable is shown in main view
432        #model = str(self.cbModel.currentText())
433        #self.SASModelToQModel(model)
434
435        pass # debug anchor
436
437    def onFit(self):
438        """
439        Perform fitting on the current data
440        """
441        fitter = Fit()
442
443        # Data going in
444        data = self.logic.data
445        model = self.kernel_module
446        qmin = self.q_range_min
447        qmax = self.q_range_max
448        params_to_fit = self.parameters_to_fit
449
450        # Potential weights added
451        self.addWeightingToData(data)
452
453        # Potential smearing added
454        smearing, accuracy, smearing_min, smearing_max = self.smearing_widget.state()
455
456        # These should be updating somehow?
457        fit_id = 0
458        constraints = []
459        smearer = None
460        page_id = [210]
461        handler = None
462        batch_inputs = {}
463        batch_outputs = {}
464        list_page_id = [page_id]
465        #---------------------------------
466
467        # Parameterize the fitter
468        fitter.set_model(model, fit_id, params_to_fit, data=data,
469                         constraints=constraints)
470        fitter.set_data(data=data, id=fit_id, smearer=smearer, qmin=qmin,
471                        qmax=qmax)
472        fitter.select_problem_for_fit(id=fit_id, value=1)
473
474        fitter.fitter_id = page_id
475
476        # Create the fitting thread, based on the fitter
477        calc_fit = FitThread(handler=handler,
478                             fn=[fitter],
479                             batch_inputs=batch_inputs,
480                             batch_outputs=batch_outputs,
481                             page_id=list_page_id,
482                             updatefn=self.updateFit,
483                             completefn=None)
484
485        # start the trhrhread
486        calc_thread = threads.deferToThread(calc_fit.compute)
487        calc_thread.addCallback(self.fitComplete)
488
489        #disable the Fit button
490        self.cmdFit.setText('Calculating...')
491        self.communicate.statusBarUpdateSignal.emit('Fitting started...')
492        self.cmdFit.setEnabled(False)
493
494    def updateFit(self):
495        """
496        """
497        print "UPDATE FIT"
498        pass
499
500    def fitComplete(self, result):
501        """
502        Receive and display fitting results
503        "result" is a tuple of actual result list and the fit time in seconds
504        """
505        #re-enable the Fit button
506        self.cmdFit.setText("Fit")
507        self.cmdFit.setEnabled(True)
508
509        assert result is not None
510
511        res_list = result[0]
512        res = res_list[0]
513        if res.fitness is None or \
514            not numpy.isfinite(res.fitness) or \
515            numpy.any(res.pvec == None) or \
516            not numpy.all(numpy.isfinite(res.pvec)):
517            msg = "Fitting did not converge!!!"
518            self.communicate.statusBarUpdateSignal.emit(msg)
519            logging.error(msg)
520            return
521
522        elapsed = result[1]
523        msg = "Fitting completed successfully in: %s s.\n" % GuiUtils.formatNumber(elapsed)
524        self.communicate.statusBarUpdateSignal.emit(msg)
525
526        fitness = res.fitness
527        param_list = res.param_list
528        param_values = res.pvec
529        param_stderr = res.stderr
530        params_and_errors = zip(param_values, param_stderr)
531        param_dict = dict(izip(param_list, params_and_errors))
532
533        # Dictionary of fitted parameter: value, error
534        # e.g. param_dic = {"sld":(1.703, 0.0034), "length":(33.455, -0.0983)}
535        self.updateModelFromList(param_dict)
536
537        # update charts
538        self.onPlot()
539
540        # Read only value - we can get away by just printing it here
541        chi2_repr = GuiUtils.formatNumber(fitness, high=True)
542        self.lblChi2Value.setText(chi2_repr)
543
544    def iterateOverModel(self, func):
545        """
546        Take func and throw it inside the model row loop
547        """
548        #assert isinstance(func, function)
549        for row_i in xrange(self._model_model.rowCount()):
550            func(row_i)
551
552    def updateModelFromList(self, param_dict):
553        """
554        Update the model with new parameters, create the errors column
555        """
556        assert isinstance(param_dict, dict)
557        if not dict:
558            return
559
560        def updateFittedValues(row_i):
561            # Utility function for main model update
562            # internal so can use closure for param_dict
563            param_name = str(self._model_model.item(row_i, 0).text())
564            if param_name not in param_dict.keys():
565                return
566            # modify the param value
567            param_repr = GuiUtils.formatNumber(param_dict[param_name][0], high=True)
568            self._model_model.item(row_i, 1).setText(param_repr)
569            if self.has_error_column:
570                error_repr = GuiUtils.formatNumber(param_dict[param_name][1], high=True)
571                self._model_model.item(row_i, 2).setText(error_repr)
572
573        def createErrorColumn(row_i):
574            # Utility function for error column update
575            item = QtGui.QStandardItem()
576            for param_name in param_dict.keys():
577                if str(self._model_model.item(row_i, 0).text()) != param_name:
578                    continue
579                error_repr = GuiUtils.formatNumber(param_dict[param_name][1], high=True)
580                item.setText(error_repr)
581            error_column.append(item)
582
583        self.iterateOverModel(updateFittedValues)
584
585        if self.has_error_column:
586            return
587
588        error_column = []
589        self.iterateOverModel(createErrorColumn)
590
591        # switch off reponse to model change
592        self._model_model.blockSignals(True)
593        self._model_model.insertColumn(2, error_column)
594        self._model_model.blockSignals(False)
595        FittingUtilities.addErrorHeadersToModel(self._model_model)
596        # Adjust the table cells width.
597        # TODO: find a way to dynamically adjust column width while resized expanding
598        self.lstParams.resizeColumnToContents(0)
599        self.lstParams.resizeColumnToContents(4)
600        self.lstParams.resizeColumnToContents(5)
601        self.lstParams.setSizePolicy(QtGui.QSizePolicy.MinimumExpanding, QtGui.QSizePolicy.Expanding)
602
603        self.has_error_column = True
604
605    def onPlot(self):
606        """
607        Plot the current set of data
608        """
609        if self.data is None :
610            self.createDefaultDataset()
611        self.calculateQGridForModel()
612
613    def onNpts(self):
614        """
615        Callback for number of points line edit update
616        """
617        # assumes type/value correctness achieved with QValidator
618        try:
619            self.npts = int(self.txtNpts.text())
620        except ValueError:
621            # TODO
622            # This will return the old value to model/view and return
623            # notifying the user about format available.
624            pass
625        # Force redisplay
626        if self.model_is_loaded:
627            self.onPlot()
628
629    def onMinRange(self):
630        """
631        Callback for minimum range of points line edit update
632        """
633        # assumes type/value correctness achieved with QValidator
634        try:
635            self.q_range_min = float(self.txtMinRange.text())
636        except ValueError:
637            # TODO
638            # This will return the old value to model/view and return
639            # notifying the user about format available.
640            return
641        # set Q range labels on the main tab
642        #self.lblMinRangeDef.setText(str(self.q_range_min))
643        if self.model_is_loaded:
644            self.onPlot()
645
646    def onMaxRange(self):
647        """
648        Callback for maximum range of points line edit update
649        """
650        # assumes type/value correctness achieved with QValidator
651        try:
652            self.q_range_max = float(self.txtMaxRange.text())
653        except:
654            pass
655        # set Q range labels on the main tab
656        self.lblMaxRangeDef.setText(str(self.q_range_max))
657        if self.model_is_loaded:
658            self.onPlot()
659
660    def onMaskEdit(self):
661        """
662        Callback for running the mask editor
663        """
664        pass
665
666    def onReset(self):
667        """
668        Callback for resetting qmin/qmax
669        """
670        pass
671
672    def setDefaultStructureCombo(self):
673        """
674        Fill in the structure factors combo box with defaults
675        """
676        structure_factor_list = self.master_category_dict.pop(CATEGORY_STRUCTURE)
677        factors = [factor[0] for factor in structure_factor_list]
678        factors.insert(0, STRUCTURE_DEFAULT)
679        self.cbStructureFactor.clear()
680        self.cbStructureFactor.addItems(sorted(factors))
681
682    def createDefaultDataset(self):
683        """
684        Generate default Dataset 1D/2D for the given model
685        """
686        # Create default datasets if no data passed
687        if self.is2D:
688            qmax = self.q_range_max/numpy.sqrt(2)
689            qstep = self.npts
690            self.logic.createDefault2dData(qmax, qstep, self.tab_id)
691        else:
692            interval = numpy.linspace(start=self.q_range_min, stop=self.q_range_max,
693                        num=self.npts, endpoint=True)
694            self.logic.createDefault1dData(interval, self.tab_id)
695
696    def readCategoryInfo(self):
697        """
698        Reads the categories in from file
699        """
700        self.master_category_dict = defaultdict(list)
701        self.by_model_dict = defaultdict(list)
702        self.model_enabled_dict = defaultdict(bool)
703
704        categorization_file = CategoryInstaller.get_user_file()
705        if not os.path.isfile(categorization_file):
706            categorization_file = CategoryInstaller.get_default_file()
707        with open(categorization_file, 'rb') as cat_file:
708            self.master_category_dict = json.load(cat_file)
709            self.regenerateModelDict()
710
711        # Load the model dict
712        models = load_standard_models()
713        for model in models:
714            self.models[model.name] = model
715
716    def regenerateModelDict(self):
717        """
718        Regenerates self.by_model_dict which has each model name as the
719        key and the list of categories belonging to that model
720        along with the enabled mapping
721        """
722        self.by_model_dict = defaultdict(list)
723        for category in self.master_category_dict:
724            for (model, enabled) in self.master_category_dict[category]:
725                self.by_model_dict[model].append(category)
726                self.model_enabled_dict[model] = enabled
727
728    def addBackgroundToModel(self, model):
729        """
730        Adds background parameter with default values to the model
731        """
732        assert isinstance(model, QtGui.QStandardItemModel)
733        checked_list = ['background', '0.001', '-inf', 'inf', '1/cm']
734        FittingUtilities.addCheckedListToModel(model, checked_list)
735
736    def addScaleToModel(self, model):
737        """
738        Adds scale parameter with default values to the model
739        """
740        assert isinstance(model, QtGui.QStandardItemModel)
741        checked_list = ['scale', '1.0', '0.0', 'inf', '']
742        FittingUtilities.addCheckedListToModel(model, checked_list)
743
744    def addWeightingToData(self, data):
745        """
746        Adds weighting contribution to fitting data
747        """
748        # Check the state of the Weighting radio buttons
749        button_id = self.weightingGroup.checkedId()
750        # Cast the id to a valid index
751        button_id = abs(button_id + 2)
752        if button_id == 0:
753            # No weight added
754            return
755        # Send original data for weighting
756        weight = get_weight(data=data, is2d=self.is2D, flag=button_id)
757        if self.is2D:
758            data.err_data = weight
759        else:
760            data.dy = weight
761
762    def updateQRange(self):
763        """
764        Updates Q Range display
765        """
766        if self.data_is_loaded:
767            self.q_range_min, self.q_range_max, self.npts = self.logic.computeDataRange()
768        # set Q range labels on the main tab
769        self.lblMinRangeDef.setText(str(self.q_range_min))
770        self.lblMaxRangeDef.setText(str(self.q_range_max))
771        # set Q range labels on the options tab
772        self.txtMaxRange.setText(str(self.q_range_max))
773        self.txtMinRange.setText(str(self.q_range_min))
774        self.txtNpts.setText(str(self.npts))
775        self.txtNptsFit.setText(str(self.npts))
776
777    def SASModelToQModel(self, model_name, structure_factor=None):
778        """
779        Setting model parameters into table based on selected category
780        """
781        # TODO - modify for structure factor-only choice
782
783        # Crete/overwrite model items
784        self._model_model.clear()
785
786        kernel_module = generate.load_kernel_module(model_name)
787        self.model_parameters = modelinfo.make_parameter_table(getattr(kernel_module, 'parameters', []))
788
789        # Instantiate the current sasmodel
790        self.kernel_module = self.models[model_name]()
791
792        # Explicitly add scale and background with default values
793        self.addScaleToModel(self._model_model)
794        self.addBackgroundToModel(self._model_model)
795
796        # Update the QModel
797        new_rows = FittingUtilities.addParametersToModel(self.model_parameters, self.is2D)
798        for row in new_rows:
799            self._model_model.appendRow(row)
800        # Update the counter used for multishell display
801        self._last_model_row = self._model_model.rowCount()
802
803        FittingUtilities.addHeadersToModel(self._model_model)
804
805        # Add structure factor
806        if structure_factor is not None and structure_factor != "None":
807            structure_module = generate.load_kernel_module(structure_factor)
808            structure_parameters = modelinfo.make_parameter_table(getattr(structure_module, 'parameters', []))
809            new_rows = FittingUtilities.addSimpleParametersToModel(structure_parameters, self.is2D)
810            for row in new_rows:
811                self._model_model.appendRow(row)
812            # Update the counter used for multishell display
813            self._last_model_row = self._model_model.rowCount()
814        else:
815            self.addStructureFactor()
816
817        # Multishell models need additional treatment
818        self.addExtraShells()
819
820        # Add polydispersity to the model
821        self.setPolyModel()
822        # Add magnetic parameters to the model
823        self.setMagneticModel()
824
825        # Adjust the table cells width
826        self.lstParams.resizeColumnToContents(0)
827        self.lstParams.setSizePolicy(QtGui.QSizePolicy.MinimumExpanding, QtGui.QSizePolicy.Expanding)
828
829        # Now we claim the model has been loaded
830        self.model_is_loaded = True
831
832        # Update Q Ranges
833        self.updateQRange()
834
835    def updateParamsFromModel(self, item):
836        """
837        Callback method for updating the sasmodel parameters with the GUI values
838        """
839        model_column = item.column()
840
841        if model_column == 0:
842            self.checkboxSelected(item)
843            return
844
845        model_row = item.row()
846        name_index = self._model_model.index(model_row, 0)
847
848        # Extract changed value. Assumes proper validation by QValidator/Delegate
849        value = float(item.text())
850        parameter_name = str(self._model_model.data(name_index).toPyObject()) # sld, background etc.
851        property_name = str(self._model_model.headerData(1, model_column).toPyObject()) # Value, min, max, etc.
852
853        self.kernel_module.params[parameter_name] = value
854
855        # min/max to be changed in self.kernel_module.details[parameter_name] = ['Ang', 0.0, inf]
856        # magnetic params in self.kernel_module.details['M0:parameter_name'] = value
857        # multishell params in self.kernel_module.details[??] = value
858
859        # Force the chart update when actual parameters changed
860        if model_column == 1:
861            self.onPlot()
862
863    def checkboxSelected(self, item):
864        # Assure we're dealing with checkboxes
865        if not item.isCheckable():
866            return
867        status = item.checkState()
868
869        def isChecked(row):
870            return self._model_model.item(row, 0).checkState() == QtCore.Qt.Checked
871
872        def isCheckable(row):
873            return self._model_model.item(row, 0).isCheckable()
874
875        # If multiple rows selected - toggle all of them, filtering uncheckable
876        rows = [s.row() for s in self.lstParams.selectionModel().selectedRows() if isCheckable(s.row())]
877
878        # Switch off signaling from the model to avoid recursion
879        self._model_model.blockSignals(True)
880        # Convert to proper indices and set requested enablement
881        items = [self._model_model.item(row, 0).setCheckState(status) for row in rows]
882        self._model_model.blockSignals(False)
883
884        # update the list of parameters to fit
885        self.parameters_to_fit = [str(self._model_model.item(row_index, 0).text())
886                                  for row_index in xrange(self._model_model.rowCount())
887                                  if isChecked(row_index)]
888
889    def nameForFittedData(self, name):
890        """
891        Generate name for the current fit
892        """
893        if self.is2D:
894            name += "2d"
895        name = "M%i [%s]" % (self.tab_id, name)
896        return name
897
898    def createNewIndex(self, fitted_data):
899        """
900        Create a model or theory index with passed Data1D/Data2D
901        """
902        if self.data_is_loaded:
903            if not fitted_data.name:
904                name = self.nameForFittedData(self.data.filename)
905                fitted_data.title = name
906                fitted_data.name = name
907                fitted_data.filename = name
908                fitted_data.symbol = "Line"
909            self.updateModelIndex(fitted_data)
910        else:
911            name = self.nameForFittedData(self.kernel_module.name)
912            fitted_data.title = name
913            fitted_data.name = name
914            fitted_data.filename = name
915            fitted_data.symbol = "Line"
916            self.createTheoryIndex(fitted_data)
917
918    def updateModelIndex(self, fitted_data):
919        """
920        Update a QStandardModelIndex containing model data
921        """
922        if fitted_data.name is None:
923            name = self.nameForFittedData(self.logic.data.filename)
924            fitted_data.title = name
925            fitted_data.name = name
926        else:
927            name = fitted_data.name
928        # Make this a line if no other defined
929        if hasattr(fitted_data, 'symbol') and fitted_data.symbol is None:
930            fitted_data.symbol = 'Line'
931        # Notify the GUI manager so it can update the main model in DataExplorer
932        GuiUtils.updateModelItemWithPlot(self._index, QtCore.QVariant(fitted_data), name)
933
934    def createTheoryIndex(self, fitted_data):
935        """
936        Create a QStandardModelIndex containing model data
937        """
938        if fitted_data.name is None:
939            name = self.nameForFittedData(self.kernel_module.name)
940            fitted_data.title = name
941            fitted_data.name = name
942            fitted_data.filename = name
943        else:
944            name = fitted_data.name
945        # Notify the GUI manager so it can create the theory model in DataExplorer
946        new_item = GuiUtils.createModelItemWithPlot(QtCore.QVariant(fitted_data), name=name)
947        self.communicate.updateTheoryFromPerspectiveSignal.emit(new_item)
948
949    def methodCalculateForData(self):
950        '''return the method for data calculation'''
951        return Calc1D if isinstance(self.data, Data1D) else Calc2D
952
953    def methodCompleteForData(self):
954        '''return the method for result parsin on calc complete '''
955        return self.complete1D if isinstance(self.data, Data1D) else self.complete2D
956
957    def calculateQGridForModel(self):
958        """
959        Prepare the fitting data object, based on current ModelModel
960        """
961        # Awful API to a backend method.
962        method = self.methodCalculateForData()(data=self.data,
963                              model=self.kernel_module,
964                              page_id=0,
965                              qmin=self.q_range_min,
966                              qmax=self.q_range_max,
967                              smearer=None,
968                              state=None,
969                              weight=None,
970                              fid=None,
971                              toggle_mode_on=False,
972                              completefn=None,
973                              update_chisqr=True,
974                              exception_handler=self.calcException,
975                              source=None)
976
977        calc_thread = threads.deferToThread(method.compute)
978        calc_thread.addCallback(self.methodCompleteForData())
979
980    def complete1D(self, return_data):
981        """
982        Plot the current 1D data
983        """
984        fitted_plot = self.logic.new1DPlot(return_data, self.tab_id)
985        self.calculateResiduals(fitted_plot)
986
987    def complete2D(self, return_data):
988        """
989        Plot the current 2D data
990        """
991        fitted_data = self.logic.new2DPlot(return_data)
992        self.calculateResiduals(fitted_data)
993
994    def calculateResiduals(self, fitted_data):
995        """
996        Calculate and print Chi2 and display chart of residuals
997        """
998        # Create a new index for holding data
999        fitted_data.symbol = "Line"
1000        self.createNewIndex(fitted_data)
1001        # Calculate difference between return_data and logic.data
1002        chi2 = FittingUtilities.calculateChi2(fitted_data, self.logic.data)
1003        # Update the control
1004        chi2_repr = "---" if chi2 is None else GuiUtils.formatNumber(chi2, high=True)
1005        self.lblChi2Value.setText(chi2_repr)
1006
1007        # Plot residuals if actual data
1008        if self.data_is_loaded:
1009            residuals_plot = FittingUtilities.plotResiduals(self.data, fitted_data)
1010            residuals_plot.id = "Residual " + residuals_plot.id
1011            self.createNewIndex(residuals_plot)
1012            self.communicate.plotUpdateSignal.emit([residuals_plot])
1013
1014        self.communicate.plotUpdateSignal.emit([fitted_data])
1015
1016    def calcException(self, etype, value, tb):
1017        """
1018        Something horrible happened in the deferred.
1019        """
1020        logging.error("".join(traceback.format_exception(etype, value, tb)))
1021
1022    def setTableProperties(self, table):
1023        """
1024        Setting table properties
1025        """
1026        # Table properties
1027        table.verticalHeader().setVisible(False)
1028        table.setAlternatingRowColors(True)
1029        table.setSizePolicy(QtGui.QSizePolicy.MinimumExpanding, QtGui.QSizePolicy.Expanding)
1030        table.setSelectionBehavior(QtGui.QAbstractItemView.SelectRows)
1031        table.resizeColumnsToContents()
1032
1033        # Header
1034        header = table.horizontalHeader()
1035        header.setResizeMode(QtGui.QHeaderView.ResizeToContents)
1036
1037        header.ResizeMode(QtGui.QHeaderView.Interactive)
1038        # Resize column 0 and 6 to content
1039        header.setResizeMode(0, QtGui.QHeaderView.ResizeToContents)
1040        header.setResizeMode(6, QtGui.QHeaderView.ResizeToContents)
1041
1042    def setPolyModel(self):
1043        """
1044        Set polydispersity values
1045        """
1046        if not self.model_parameters:
1047            return
1048        self._poly_model.clear()
1049        for row, param in enumerate(self.model_parameters.form_volume_parameters):
1050            # Counters should not be included
1051            if not param.polydisperse:
1052                continue
1053
1054            # Potential multishell params
1055            checked_list = ["Distribution of "+param.name, str(param.default),
1056                            str(param.limits[0]), str(param.limits[1]),
1057                            "35", "3", ""]
1058            FittingUtilities.addCheckedListToModel(self._poly_model, checked_list)
1059
1060            #TODO: Need to find cleaner way to input functions
1061            func = QtGui.QComboBox()
1062            func.addItems(['rectangle', 'array', 'lognormal', 'gaussian', 'schulz',])
1063            func_index = self.lstPoly.model().index(row, 6)
1064            self.lstPoly.setIndexWidget(func_index, func)
1065
1066        FittingUtilities.addPolyHeadersToModel(self._poly_model)
1067
1068    def setMagneticModel(self):
1069        """
1070        Set magnetism values on model
1071        """
1072        if not self.model_parameters:
1073            return
1074        self._magnet_model.clear()
1075        for param in self.model_parameters.call_parameters:
1076            if param.type != "magnetic":
1077                continue
1078            checked_list = [param.name,
1079                            str(param.default),
1080                            str(param.limits[0]),
1081                            str(param.limits[1]),
1082                            param.units]
1083            FittingUtilities.addCheckedListToModel(self._magnet_model, checked_list)
1084
1085        FittingUtilities.addHeadersToModel(self._magnet_model)
1086
1087    def addStructureFactor(self):
1088        """
1089        Add structure factors to the list of parameters
1090        """
1091        if self.kernel_module.is_form_factor:
1092            self.enableStructureCombo()
1093        else:
1094            self.disableStructureCombo()
1095
1096    def addExtraShells(self):
1097        """
1098        Add a combobox for multiple shell display
1099        """
1100        param_name, param_length = FittingUtilities.getMultiplicity(self.model_parameters)
1101
1102        if param_length == 0:
1103            return
1104
1105        # cell 1: variable name
1106        item1 = QtGui.QStandardItem(param_name)
1107
1108        func = QtGui.QComboBox()
1109        # Available range of shells displayed in the combobox
1110        func.addItems([str(i) for i in xrange(param_length+1)])
1111
1112        # Respond to index change
1113        func.currentIndexChanged.connect(self.modifyShellsInList)
1114
1115        # cell 2: combobox
1116        item2 = QtGui.QStandardItem()
1117        self._model_model.appendRow([item1, item2])
1118
1119        # Beautify the row:  span columns 2-4
1120        shell_row = self._model_model.rowCount()
1121        shell_index = self._model_model.index(shell_row-1, 1)
1122
1123        self.lstParams.setIndexWidget(shell_index, func)
1124        self._last_model_row = self._model_model.rowCount()
1125
1126        # Set the index to the state-kept value
1127        func.setCurrentIndex(self.current_shell_displayed
1128                             if self.current_shell_displayed < func.count() else 0)
1129
1130    def modifyShellsInList(self, index):
1131        """
1132        Add/remove additional multishell parameters
1133        """
1134        # Find row location of the combobox
1135        last_row = self._last_model_row
1136        remove_rows = self._model_model.rowCount() - last_row
1137
1138        if remove_rows > 1:
1139            self._model_model.removeRows(last_row, remove_rows)
1140
1141        FittingUtilities.addShellsToModel(self.model_parameters, self._model_model, index)
1142        self.current_shell_displayed = index
1143
Note: See TracBrowser for help on using the repository browser.