source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingWidget.py @ 2add354

Branches: ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc
Last change on this file since 2add354 was 2add354, checked in by Piotr Rozyczko <rozyczko@…>, 7 years ago

Code review fixes for SASVIEW-273

  • Property mode set to 100644
File size: 41.3 KB
Line 
1import sys
2import json
3import os
4import numpy as np
5from collections import defaultdict
6from itertools import izip
7
8import logging
9import traceback
10from twisted.internet import threads
11
12from PyQt4 import QtGui
13from PyQt4 import QtCore
14from PyQt4 import QtWebKit
15
16from sasmodels import generate
17from sasmodels import modelinfo
18from sasmodels.sasview_model import load_standard_models
19from sas.sascalc.fit.BumpsFitting import BumpsFit as Fit
20from sas.sasgui.perspectives.fitting.fit_thread import FitThread
21
22from sas.sasgui.guiframe.CategoryInstaller import CategoryInstaller
23from sas.sasgui.guiframe.dataFitting import Data1D
24from sas.sasgui.guiframe.dataFitting import Data2D
25import sas.qtgui.Utilities.GuiUtils as GuiUtils
26from sas.sasgui.perspectives.fitting.model_thread import Calc1D
27from sas.sasgui.perspectives.fitting.model_thread import Calc2D
28from sas.sasgui.perspectives.fitting.utils import get_weight
29
30from UI.FittingWidgetUI import Ui_FittingWidgetUI
31from sas.qtgui.Perspectives.Fitting.FittingLogic import FittingLogic
32from sas.qtgui.Perspectives.Fitting import FittingUtilities
33from SmearingWidget import SmearingWidget
34from OptionsWidget import OptionsWidget
35from FitPage import FitPage
36
# Indices of the optional tabs inside tabFitting (enabled by their check boxes)
TAB_POLY = 3
TAB_MAGNETISM = 4

# Special entries of the category and structure factor combo boxes
CATEGORY_DEFAULT = "Choose category..."
CATEGORY_STRUCTURE = "Structure Factor"
STRUCTURE_DEFAULT = "None"
42
43class FittingWidget(QtGui.QWidget, Ui_FittingWidgetUI):
44    """
45    Main widget for selecting form and structure factor models
46    """
    def __init__(self, parent=None, data=None, id=1):
        """
        Build the widget, its option/resolution tabs and item models, and
        connect the GUI signals.

        :param parent: owner object; ``parent.communicate`` is read
            unconditionally below, so the ``None`` default does not work.
        :param data: optional QStandardItem holding a dataset; it is handed to
            FittingLogic as-is and then assigned through the ``data`` property
            setter at the very end of construction.
        :param id: identifier of the hosting tab, stored as ``self.tab_id``
            and passed on when default datasets are created.
        """

        super(FittingWidget, self).__init__()

        # Necessary globals
        self.parent = parent
        # SasModel is loaded
        self.model_is_loaded = False
        # Data[12]D passed and set
        self.data_is_loaded = False
        # Current SasModel in view
        self.kernel_module = None
        # Current SasModel view dimension
        self.is2D = False
        # Current SasModel is multishell
        self.model_has_shells = False
        # Utility variable to enable unselectable option in category combobox
        self._previous_category_index = 0
        # Utility variable for multishell display
        self._last_model_row = 0
        # Dictionary of {model name: model class} for all standard sasmodels
        # (filled by readCategoryInfo)
        self.models = {}
        # Parameters to fit
        self.parameters_to_fit = None
        # Fit options - defaults, later overridden by the data range
        # (updateQRange) or the options tab (onOptionsUpdate)
        self.q_range_min = 0.005
        self.q_range_max = 0.1
        self.npts = 25
        self.log_points = False
        self.weighting = 0
        self.chi2 = None

        # Data for chosen model
        self.model_data = None

        # Which tab is this widget displayed in?
        self.tab_id = id

        # Which shell is being currently displayed?
        self.current_shell_displayed = 0
        self.has_error_column = False

        # Main Data[12]D holder
        self.logic = FittingLogic(data=data)

        # Main GUI setup up - must precede any access to the designer widgets
        self.setupUi(self)
        self.setWindowTitle("Fitting")
        self.communicate = self.parent.communicate

        # Options widget
        layout = QtGui.QGridLayout()
        self.options_widget = OptionsWidget(self, self.logic)
        layout.addWidget(self.options_widget)
        self.tabOptions.setLayout(layout)

        # Smearing widget
        layout = QtGui.QGridLayout()
        self.smearing_widget = SmearingWidget(self)
        layout.addWidget(self.smearing_widget)
        self.tabResolution.setLayout(layout)

        # Define bold font for use in various controls
        self.boldFont=QtGui.QFont()
        self.boldFont.setBold(True)

        # Set data label
        self.label.setFont(self.boldFont)
        self.label.setText("No data loaded")
        self.lblFilename.setText("")

        # Set the main models
        # We can't use a single model here, due to restrictions on flattening
        # the model tree with subclassed QAbstractProxyModel...
        self._model_model = QtGui.QStandardItemModel()
        self._poly_model = QtGui.QStandardItemModel()
        self._magnet_model = QtGui.QStandardItemModel()

        # Param model displayed in param list
        self.lstParams.setModel(self._model_model)
        # Loads master_category_dict and self.models; needed by the
        # structure factor and category combos below
        self.readCategoryInfo()
        self.model_parameters = None
        self.lstParams.setAlternatingRowColors(True)
        stylesheet = """
            QTreeView{
                alternate-background-color: #f6fafb;
                background: #e8f4fc;
            }
        """
        self.lstParams.setStyleSheet(stylesheet)
        # Right click on the parameter list shows the model description
        self.lstParams.setContextMenuPolicy(QtCore.Qt.CustomContextMenu)
        self.lstParams.customContextMenuRequested.connect(self.showModelDescription)

        # Poly model displayed in poly list
        self.lstPoly.setModel(self._poly_model)
        self.setPolyModel()
        self.setTableProperties(self.lstPoly)

        # Magnetism model displayed in magnetism list
        self.lstMagnetic.setModel(self._magnet_model)
        self.setMagneticModel()
        self.setTableProperties(self.lstMagnetic)

        # Defaults for the structure factors. This also removes the
        # structure factor category from master_category_dict.
        self.setDefaultStructureCombo()

        # Make structure factor and model CBs disabled
        self.disableModelCombo()
        self.disableStructureCombo()

        # Generate the category list for display: placeholder first, regular
        # categories sorted, structure factors appended last.
        # Signals are not connected yet, so no selection handler fires here.
        category_list = sorted(self.master_category_dict.keys())
        self.cbCategory.addItem(CATEGORY_DEFAULT)
        self.cbCategory.addItems(category_list)
        self.cbCategory.addItem(CATEGORY_STRUCTURE)
        self.cbCategory.setCurrentIndex(0)

        # Connect signals to controls
        self.initializeSignals()

        # Initial control state
        self.initializeControls()

        # Display HTML content
        self.helpView = QtWebKit.QWebView()

        self._index = None
        # Route through the property setter so the UI is updated for the data
        if data is not None:
            self.data = data
176
177    def close(self):
178        """
179        Remember to kill off things on exit
180        """
181        self.helpView.close()
182        del self.helpView
183
184    @property
185    def data(self):
186        return self.logic.data
187
188    @data.setter
189    def data(self, value):
190        """ data setter """
191        assert isinstance(value, QtGui.QStandardItem)
192        # _index contains the QIndex with data
193        self._index = value
194
195        # Update logics with data items
196        self.logic.data = GuiUtils.dataFromItem(value)
197
198        # Overwrite data type descriptor
199        self.is2D = True if isinstance(self.logic.data, Data2D) else False
200
201        self.data_is_loaded = True
202
203        # Enable/disable UI components
204        self.setEnablementOnDataLoad()
205
206    def setEnablementOnDataLoad(self):
207        """
208        Enable/disable various UI elements based on data loaded
209        """
210        # Tag along functionality
211        self.label.setText("Data loaded from: ")
212        self.lblFilename.setText(self.logic.data.filename)
213        self.updateQRange()
214        # Switch off Data2D control
215        self.chk2DView.setEnabled(False)
216        self.chk2DView.setVisible(False)
217        self.chkMagnetism.setEnabled(True)
218        # Similarly on other tabs
219        self.options_widget.setEnablementOnDataLoad()
220
221        # Smearing tab
222        self.smearing_widget.updateSmearing(self.data)
223
224    def acceptsData(self):
225        """ Tells the caller this widget can accept new dataset """
226        return not self.data_is_loaded
227
228    def disableModelCombo(self):
229        """ Disable the combobox """
230        self.cbModel.setEnabled(False)
231        self.lblModel.setEnabled(False)
232
233    def enableModelCombo(self):
234        """ Enable the combobox """
235        self.cbModel.setEnabled(True)
236        self.lblModel.setEnabled(True)
237
238    def disableStructureCombo(self):
239        """ Disable the combobox """
240        self.cbStructureFactor.setEnabled(False)
241        self.lblStructure.setEnabled(False)
242
243    def enableStructureCombo(self):
244        """ Enable the combobox """
245        self.cbStructureFactor.setEnabled(True)
246        self.lblStructure.setEnabled(True)
247
248    def togglePoly(self, isChecked):
249        """ Enable/disable the polydispersity tab """
250        self.tabFitting.setTabEnabled(TAB_POLY, isChecked)
251
252    def toggleMagnetism(self, isChecked):
253        """ Enable/disable the magnetism tab """
254        self.tabFitting.setTabEnabled(TAB_MAGNETISM, isChecked)
255
256    def toggle2D(self, isChecked):
257        """ Enable/disable the controls dependent on 1D/2D data instance """
258        self.chkMagnetism.setEnabled(isChecked)
259        self.is2D = isChecked
260        # Reload the current model
261        if self.kernel_module:
262            self.onSelectModel()
263
264    def initializeControls(self):
265        """
266        Set initial control enablement
267        """
268        self.cmdFit.setEnabled(False)
269        self.cmdPlot.setEnabled(False)
270        self.options_widget.cmdComputePoints.setVisible(False) # probably redundant
271        self.chkPolydispersity.setEnabled(True)
272        self.chkPolydispersity.setCheckState(False)
273        self.chk2DView.setEnabled(True)
274        self.chk2DView.setCheckState(False)
275        self.chkMagnetism.setEnabled(False)
276        self.chkMagnetism.setCheckState(False)
277        # Tabs
278        self.tabFitting.setTabEnabled(TAB_POLY, False)
279        self.tabFitting.setTabEnabled(TAB_MAGNETISM, False)
280        self.lblChi2Value.setText("---")
281        # Smearing tab
282        self.smearing_widget.updateSmearing(self.data)
283        # Line edits in the option tab
284        self.updateQRange()
285
286    def initializeSignals(self):
287        """
288        Connect GUI element signals
289        """
290        # Comboboxes
291        self.cbStructureFactor.currentIndexChanged.connect(self.onSelectStructureFactor)
292        self.cbCategory.currentIndexChanged.connect(self.onSelectCategory)
293        self.cbModel.currentIndexChanged.connect(self.onSelectModel)
294        # Checkboxes
295        self.chk2DView.toggled.connect(self.toggle2D)
296        self.chkPolydispersity.toggled.connect(self.togglePoly)
297        self.chkMagnetism.toggled.connect(self.toggleMagnetism)
298        # Buttons
299        self.cmdFit.clicked.connect(self.onFit)
300        self.cmdPlot.clicked.connect(self.onPlot)
301        self.cmdHelp.clicked.connect(self.onHelp)
302
303        # Respond to change in parameters from the UI
304        self._model_model.itemChanged.connect(self.updateParamsFromModel)
305        self._poly_model.itemChanged.connect(self.onPolyModelChange)
306        # TODO after the poly_model prototype accepted
307        #self._magnet_model.itemChanged.connect(self.onMagneticModelChange)
308
309        # Signals from separate tabs asking for replot
310        self.options_widget.plot_signal.connect(self.onOptionsUpdate)
311
312    def showModelDescription(self, position):
313        """
314        Shows a window with model description, when right clicked in the treeview
315        """
316        msg = 'Model description:\n'
317        info = "Info"
318        if self.kernel_module is not None:
319            if str(self.kernel_module.description).rstrip().lstrip() == '':
320                msg += "Sorry, no information is available for this model."
321            else:
322                msg += self.kernel_module.description + '\n'
323        else:
324            msg += "You must select a model to get information on this"
325
326        menu = QtGui.QMenu()
327        label = QtGui.QLabel(msg)
328        action = QtGui.QWidgetAction(self)
329        action.setDefaultWidget(label)
330        menu.addAction(action)
331        menu.exec_(self.lstParams.viewport().mapToGlobal(position))
332
333    def onSelectModel(self):
334        """
335        Respond to select Model from list event
336        """
337        model = str(self.cbModel.currentText())
338
339        # Reset structure factor
340        self.cbStructureFactor.setCurrentIndex(0)
341
342        # Reset parameters to fit
343        self.parameters_to_fit = None
344        self.has_error_column = False
345
346        # Set enablement on calculate/plot
347        self.cmdPlot.setEnabled(True)
348
349        # SasModel -> QModel
350        self.SASModelToQModel(model)
351
352        if self.data_is_loaded:
353            self.cmdPlot.setText("Show Plot")
354            self.calculateQGridForModel()
355        else:
356            self.cmdPlot.setText("Calculate")
357            # Create default datasets if no data passed
358            self.createDefaultDataset()
359
360        #state = self.currentState()
361
362    def onSelectStructureFactor(self):
363        """
364        Select Structure Factor from list
365        """
366        model = str(self.cbModel.currentText())
367        category = str(self.cbCategory.currentText())
368        structure = str(self.cbStructureFactor.currentText())
369        if category == CATEGORY_STRUCTURE:
370            model = None
371        self.SASModelToQModel(model, structure_factor=structure)
372
373    def onSelectCategory(self):
374        """
375        Select Category from list
376        """
377        category = str(self.cbCategory.currentText())
378        # Check if the user chose "Choose category entry"
379        if category == CATEGORY_DEFAULT:
380            # if the previous category was not the default, keep it.
381            # Otherwise, just return
382            if self._previous_category_index != 0:
383                # We need to block signals, or else state changes on perceived unchanged conditions
384                self.cbCategory.blockSignals(True)
385                self.cbCategory.setCurrentIndex(self._previous_category_index)
386                self.cbCategory.blockSignals(False)
387            return
388
389        if category == CATEGORY_STRUCTURE:
390            self.disableModelCombo()
391            self.enableStructureCombo()
392            self._model_model.clear()
393            return
394
395        # Safely clear and enable the model combo
396        self.cbModel.blockSignals(True)
397        self.cbModel.clear()
398        self.cbModel.blockSignals(False)
399        self.enableModelCombo()
400        self.disableStructureCombo()
401
402        self._previous_category_index = self.cbCategory.currentIndex()
403        # Retrieve the list of models
404        model_list = self.master_category_dict[category]
405        models = []
406        # Populate the models combobox
407        self.cbModel.addItems(sorted([model for (model, _) in model_list]))
408
409    def onPolyModelChange(self, item):
410        """
411        Callback method for updating the main model and sasmodel
412        parameters with the GUI values in the polydispersity view
413        """
414        model_column = item.column()
415        model_row = item.row()
416        name_index = self._poly_model.index(model_row, 0)
417        # Extract changed value. Assumes proper validation by QValidator/Delegate
418        # Checkbox in column 0
419        if model_column == 0:
420            value = item.checkState()
421        else:
422            try:
423                value = float(item.text())
424            except ValueError:
425                # Can't be converted properly, bring back the old value and exit
426                return
427
428        parameter_name = str(self._poly_model.data(name_index).toPyObject()) # "distribution of sld" etc.
429        if "Distribution of" in parameter_name:
430            parameter_name = parameter_name[16:]
431        property_name = str(self._poly_model.headerData(model_column, 1).toPyObject()) # Value, min, max, etc.
432        # print "%s(%s) => %d" % (parameter_name, property_name, value)
433
434        # Update the sasmodel
435        #self.kernel_module.params[parameter_name] = value
436
437        # Reload the main model - may not be required if no variable is shown in main view
438        #model = str(self.cbModel.currentText())
439        #self.SASModelToQModel(model)
440
441        pass # debug anchor
442
443    def onHelp(self):
444        """
445        Show the "Fitting" section of help
446        """
447        tree_location = self.parent.HELP_DIRECTORY_LOCATION +\
448            "/user/sasgui/perspectives/fitting/fitting_help.html"
449        self.helpView.load(QtCore.QUrl(tree_location))
450        self.helpView.show()
451
    def onFit(self):
        """
        Perform fitting on the current data.

        Configures a BumpsFit fitter for the current model and parameters and
        runs it via FitThread in a twisted worker thread. fitComplete is the
        callback and fitFailed the errback. The Fit button is disabled while
        the fit runs; fitComplete re-enables it.
        """
        fitter = Fit()

        # Data going in
        data = self.logic.data
        model = self.kernel_module
        qmin = self.q_range_min
        qmax = self.q_range_max
        params_to_fit = self.parameters_to_fit

        # Potential weights added directly to data
        self.addWeightingToData(data)

        # Potential smearing added
        # Remember that smearing_min/max can be None ->
        # deal with it until Python gets discriminated unions
        # NOTE(review): these values are read but not used below; the fitter
        # is given smearer=None.
        smearing, accuracy, smearing_min, smearing_max = self.smearing_widget.state()

        # These should be updating somehow?
        # Placeholder identifiers and containers expected by Fit/FitThread
        fit_id = 0
        constraints = []
        smearer = None
        page_id = [210]
        handler = None
        batch_inputs = {}
        batch_outputs = {}
        list_page_id = [page_id]
        #---------------------------------

        # Parameterize the fitter
        fitter.set_model(model, fit_id, params_to_fit, data=data,
                         constraints=constraints)

        fitter.set_data(data=data, id=fit_id, smearer=smearer, qmin=qmin,
                        qmax=qmax)
        fitter.select_problem_for_fit(id=fit_id, value=1)

        fitter.fitter_id = page_id

        # Create the fitting thread, based on the fitter
        calc_fit = FitThread(handler=handler,
                             fn=[fitter],
                             batch_inputs=batch_inputs,
                             batch_outputs=batch_outputs,
                             page_id=list_page_id,
                             updatefn=self.updateFit,
                             completefn=None)

        # start the thread
        calc_thread = threads.deferToThread(calc_fit.compute)
        # The errback is added after the callback, so it also sees errors
        # raised inside fitComplete
        calc_thread.addCallback(self.fitComplete)
        calc_thread.addErrback(self.fitFailed)

        #disable the Fit button
        self.cmdFit.setText('Calculating...')
        self.communicate.statusBarUpdateSignal.emit('Fitting started...')
        self.cmdFit.setEnabled(False)
512
513    def updateFit(self):
514        """
515        """
516        print "UPDATE FIT"
517        pass
518
519    def fitFailed(self, reason):
520        """
521        """
522        print "FIT FAILED: ", reason
523        pass
524
525    def fitComplete(self, result):
526        """
527        Receive and display fitting results
528        "result" is a tuple of actual result list and the fit time in seconds
529        """
530        #re-enable the Fit button
531        self.cmdFit.setText("Fit")
532        self.cmdFit.setEnabled(True)
533
534        assert result is not None
535
536        res_list = result[0]
537        res = res_list[0]
538        if res.fitness is None or \
539            not np.isfinite(res.fitness) or \
540            np.any(res.pvec == None) or \
541            not np.all(np.isfinite(res.pvec)):
542            msg = "Fitting did not converge!!!"
543            self.communicate.statusBarUpdateSignal.emit(msg)
544            logging.error(msg)
545            return
546
547        elapsed = result[1]
548        msg = "Fitting completed successfully in: %s s.\n" % GuiUtils.formatNumber(elapsed)
549        self.communicate.statusBarUpdateSignal.emit(msg)
550
551        self.chi2 = res.fitness
552        param_list = res.param_list
553        param_values = res.pvec
554        param_stderr = res.stderr
555        params_and_errors = zip(param_values, param_stderr)
556        param_dict = dict(izip(param_list, params_and_errors))
557
558        # Dictionary of fitted parameter: value, error
559        # e.g. param_dic = {"sld":(1.703, 0.0034), "length":(33.455, -0.0983)}
560        self.updateModelFromList(param_dict)
561
562        # update charts
563        self.onPlot()
564
565        # Read only value - we can get away by just printing it here
566        chi2_repr = GuiUtils.formatNumber(self.chi2, high=True)
567        self.lblChi2Value.setText(chi2_repr)
568
569    def iterateOverModel(self, func):
570        """
571        Take func and throw it inside the model row loop
572        """
573        #assert isinstance(func, function)
574        for row_i in xrange(self._model_model.rowCount()):
575            func(row_i)
576
577    def updateModelFromList(self, param_dict):
578        """
579        Update the model with new parameters, create the errors column
580        """
581        assert isinstance(param_dict, dict)
582        if not dict:
583            return
584
585        def updateFittedValues(row_i):
586            # Utility function for main model update
587            # internal so can use closure for param_dict
588            param_name = str(self._model_model.item(row_i, 0).text())
589            if param_name not in param_dict.keys():
590                return
591            # modify the param value
592            param_repr = GuiUtils.formatNumber(param_dict[param_name][0], high=True)
593            self._model_model.item(row_i, 1).setText(param_repr)
594            if self.has_error_column:
595                error_repr = GuiUtils.formatNumber(param_dict[param_name][1], high=True)
596                self._model_model.item(row_i, 2).setText(error_repr)
597
598        def createErrorColumn(row_i):
599            # Utility function for error column update
600            item = QtGui.QStandardItem()
601            for param_name in param_dict.keys():
602                if str(self._model_model.item(row_i, 0).text()) != param_name:
603                    continue
604                error_repr = GuiUtils.formatNumber(param_dict[param_name][1], high=True)
605                item.setText(error_repr)
606            error_column.append(item)
607
608        self.iterateOverModel(updateFittedValues)
609
610        if self.has_error_column:
611            return
612
613        error_column = []
614        self.iterateOverModel(createErrorColumn)
615
616        # switch off reponse to model change
617        self._model_model.blockSignals(True)
618        self._model_model.insertColumn(2, error_column)
619        self._model_model.blockSignals(False)
620        FittingUtilities.addErrorHeadersToModel(self._model_model)
621        # Adjust the table cells width.
622        # TODO: find a way to dynamically adjust column width while resized expanding
623        self.lstParams.resizeColumnToContents(0)
624        self.lstParams.resizeColumnToContents(4)
625        self.lstParams.resizeColumnToContents(5)
626        self.lstParams.setSizePolicy(QtGui.QSizePolicy.MinimumExpanding, QtGui.QSizePolicy.Expanding)
627
628        self.has_error_column = True
629
630    def onPlot(self):
631        """
632        Plot the current set of data
633        """
634        # Regardless of previous state, this should now be `plot show` functionality only
635        self.cmdPlot.setText("Show Plot")
636        self.showPlot()
637
638    def recalculatePlotData(self):
639        """
640        Generate a new dataset for model
641        """
642        if not self.data_is_loaded:
643            self.createDefaultDataset()
644        self.calculateQGridForModel()
645
646    def showPlot(self):
647        """
648        Show the current plot in MPL
649        """
650        # Show the chart if ready
651        data_to_show = self.data if self.data_is_loaded else self.model_data
652        if data_to_show is not None:
653            self.communicate.plotRequestedSignal.emit([data_to_show])
654
655    def onOptionsUpdate(self):
656        """
657        Update local option values and replot
658        """
659        self.q_range_min, self.q_range_max, self.npts, self.log_points, self.weighting = \
660            self.options_widget.state()
661        # set Q range labels on the main tab
662        self.lblMinRangeDef.setText(str(self.q_range_min))
663        self.lblMaxRangeDef.setText(str(self.q_range_max))
664        self.recalculatePlotData()
665
666    def setDefaultStructureCombo(self):
667        """
668        Fill in the structure factors combo box with defaults
669        """
670        structure_factor_list = self.master_category_dict.pop(CATEGORY_STRUCTURE)
671        factors = [factor[0] for factor in structure_factor_list]
672        factors.insert(0, STRUCTURE_DEFAULT)
673        self.cbStructureFactor.clear()
674        self.cbStructureFactor.addItems(sorted(factors))
675
676    def createDefaultDataset(self):
677        """
678        Generate default Dataset 1D/2D for the given model
679        """
680        # Create default datasets if no data passed
681        if self.is2D:
682            qmax = self.q_range_max/np.sqrt(2)
683            qstep = self.npts
684            self.logic.createDefault2dData(qmax, qstep, self.tab_id)
685            return
686        elif self.log_points:
687            qmin = -10.0 if self.q_range_min < 1.e-10 else np.log10(self.q_range_min)
688            qmax =  10.0 if self.q_range_max > 1.e10 else np.log10(self.q_range_max)
689            interval = np.logspace(start=qmin, stop=qmax, num=self.npts, endpoint=True, base=10.0)
690        else:
691            interval = np.linspace(start=self.q_range_min, stop=self.q_range_max,
692                    num=self.npts, endpoint=True)
693        self.logic.createDefault1dData(interval, self.tab_id)
694
695    def readCategoryInfo(self):
696        """
697        Reads the categories in from file
698        """
699        self.master_category_dict = defaultdict(list)
700        self.by_model_dict = defaultdict(list)
701        self.model_enabled_dict = defaultdict(bool)
702
703        categorization_file = CategoryInstaller.get_user_file()
704        if not os.path.isfile(categorization_file):
705            categorization_file = CategoryInstaller.get_default_file()
706        with open(categorization_file, 'rb') as cat_file:
707            self.master_category_dict = json.load(cat_file)
708            self.regenerateModelDict()
709
710        # Load the model dict
711        models = load_standard_models()
712        for model in models:
713            self.models[model.name] = model
714
715    def regenerateModelDict(self):
716        """
717        Regenerates self.by_model_dict which has each model name as the
718        key and the list of categories belonging to that model
719        along with the enabled mapping
720        """
721        self.by_model_dict = defaultdict(list)
722        for category in self.master_category_dict:
723            for (model, enabled) in self.master_category_dict[category]:
724                self.by_model_dict[model].append(category)
725                self.model_enabled_dict[model] = enabled
726
727    def addBackgroundToModel(self, model):
728        """
729        Adds background parameter with default values to the model
730        """
731        assert isinstance(model, QtGui.QStandardItemModel)
732        checked_list = ['background', '0.001', '-inf', 'inf', '1/cm']
733        FittingUtilities.addCheckedListToModel(model, checked_list)
734        last_row = model.rowCount()-1
735        model.item(last_row, 0).setEditable(False)
736        model.item(last_row, 4).setEditable(False)
737
738    def addScaleToModel(self, model):
739        """
740        Adds scale parameter with default values to the model
741        """
742        assert isinstance(model, QtGui.QStandardItemModel)
743        checked_list = ['scale', '1.0', '0.0', 'inf', '']
744        FittingUtilities.addCheckedListToModel(model, checked_list)
745        last_row = model.rowCount()-1
746        model.item(last_row, 0).setEditable(False)
747        model.item(last_row, 4).setEditable(False)
748
749    def addWeightingToData(self, data):
750        """
751        Adds weighting contribution to fitting data
752        #"""
753        # Send original data for weighting
754        weight = get_weight(data=data, is2d=self.is2D, flag=self.weighting)
755        update_module = data.err_data if self.is2D else data.dy
756        update_module = weight
757
758    def updateQRange(self):
759        """
760        Updates Q Range display
761        """
762        if self.data_is_loaded:
763            self.q_range_min, self.q_range_max, self.npts = self.logic.computeDataRange()
764        # set Q range labels on the main tab
765        self.lblMinRangeDef.setText(str(self.q_range_min))
766        self.lblMaxRangeDef.setText(str(self.q_range_max))
767        # set Q range labels on the options tab
768        self.options_widget.updateQRange(self.q_range_min, self.q_range_max, self.npts)
769
    def SASModelToQModel(self, model_name, structure_factor=None):
        """
        Setting model parameters into table based on selected category.

        Rebuilds the main parameter model: scale and background rows, the
        parameters of ``model_name``, then either the parameters of
        ``structure_factor`` or whatever addStructureFactor adds, followed by
        the multishell, polydispersity and magnetism parts.

        :param model_name: sasmodel name; used to load the kernel module and
            as a key into ``self.models``.
        :param structure_factor: structure factor name; None or "None" means
            no structure factor.
        """
        # TODO - modify for structure factor-only choice
        # NOTE(review): model_name=None (structure-factor-only selection) is
        # passed straight to load_kernel_module and self.models below.

        # Create/overwrite model items
        self._model_model.clear()

        kernel_module = generate.load_kernel_module(model_name)
        self.model_parameters = modelinfo.make_parameter_table(getattr(kernel_module, 'parameters', []))

        # Instantiate the current sasmodel
        self.kernel_module = self.models[model_name]()

        # Explicitly add scale and background with default values
        self.addScaleToModel(self._model_model)
        self.addBackgroundToModel(self._model_model)

        # Update the QModel
        new_rows = FittingUtilities.addParametersToModel(self.model_parameters, self.is2D)
        for row in new_rows:
            self._model_model.appendRow(row)
        # Update the counter used for multishell display
        self._last_model_row = self._model_model.rowCount()

        FittingUtilities.addHeadersToModel(self._model_model)

        # Add structure factor
        if structure_factor is not None and structure_factor != "None":
            structure_module = generate.load_kernel_module(structure_factor)
            structure_parameters = modelinfo.make_parameter_table(getattr(structure_module, 'parameters', []))
            new_rows = FittingUtilities.addSimpleParametersToModel(structure_parameters, self.is2D)
            for row in new_rows:
                self._model_model.appendRow(row)
            # Update the counter used for multishell display
            self._last_model_row = self._model_model.rowCount()
        else:
            self.addStructureFactor()

        # Multishell models need additional treatment
        self.addExtraShells()

        # Add polydispersity to the model
        self.setPolyModel()
        # Add magnetic parameters to the model
        self.setMagneticModel()

        # Adjust the table cells width
        self.lstParams.resizeColumnToContents(0)
        self.lstParams.setSizePolicy(QtGui.QSizePolicy.MinimumExpanding, QtGui.QSizePolicy.Expanding)

        # Now we claim the model has been loaded
        self.model_is_loaded = True

        # Update Q Ranges
        self.updateQRange()
827
828    def updateParamsFromModel(self, item):
829        """
830        Callback method for updating the sasmodel parameters with the GUI values
831        """
832        model_column = item.column()
833
834        if model_column == 0:
835            self.checkboxSelected(item)
836            self.cmdFit.setEnabled(self.parameters_to_fit != [] and self.logic.data_is_loaded)
837            return
838
839        model_row = item.row()
840        name_index = self._model_model.index(model_row, 0)
841
842        # Extract changed value. Assumes proper validation by QValidator/Delegate
843        # TODO: disable model update for uneditable cells/columns
844        try:
845            value = float(item.text())
846        except ValueError:
847            # Unparsable field
848            return
849        parameter_name = str(self._model_model.data(name_index).toPyObject()) # sld, background etc.
850        property_name = str(self._model_model.headerData(1, model_column).toPyObject()) # Value, min, max, etc.
851
852        self.kernel_module.params[parameter_name] = value
853
854        # min/max to be changed in self.kernel_module.details[parameter_name] = ['Ang', 0.0, inf]
855        # magnetic params in self.kernel_module.details['M0:parameter_name'] = value
856        # multishell params in self.kernel_module.details[??] = value
857
858        # Force the chart update when actual parameters changed
859        if model_column == 1:
860            self.recalculatePlotData()
861
862    def checkboxSelected(self, item):
863        # Assure we're dealing with checkboxes
864        if not item.isCheckable():
865            return
866        status = item.checkState()
867
868        def isChecked(row):
869            return self._model_model.item(row, 0).checkState() == QtCore.Qt.Checked
870
871        def isCheckable(row):
872            return self._model_model.item(row, 0).isCheckable()
873
874        # If multiple rows selected - toggle all of them, filtering uncheckable
875        rows = [s.row() for s in self.lstParams.selectionModel().selectedRows() if isCheckable(s.row())]
876
877        # Switch off signaling from the model to avoid recursion
878        self._model_model.blockSignals(True)
879        # Convert to proper indices and set requested enablement
880        items = [self._model_model.item(row, 0).setCheckState(status) for row in rows]
881        self._model_model.blockSignals(False)
882
883        # update the list of parameters to fit
884        self.parameters_to_fit = [str(self._model_model.item(row_index, 0).text())
885                                  for row_index in xrange(self._model_model.rowCount())
886                                  if isChecked(row_index)]
887
888    def nameForFittedData(self, name):
889        """
890        Generate name for the current fit
891        """
892        if self.is2D:
893            name += "2d"
894        name = "M%i [%s]" % (self.tab_id, name)
895        return name
896
897    def createNewIndex(self, fitted_data):
898        """
899        Create a model or theory index with passed Data1D/Data2D
900        """
901        if self.data_is_loaded:
902            if not fitted_data.name:
903                name = self.nameForFittedData(self.data.filename)
904                fitted_data.title = name
905                fitted_data.name = name
906                fitted_data.filename = name
907                fitted_data.symbol = "Line"
908            self.updateModelIndex(fitted_data)
909        else:
910            name = self.nameForFittedData(self.kernel_module.name)
911            fitted_data.title = name
912            fitted_data.name = name
913            fitted_data.filename = name
914            fitted_data.symbol = "Line"
915            self.createTheoryIndex(fitted_data)
916
917    def updateModelIndex(self, fitted_data):
918        """
919        Update a QStandardModelIndex containing model data
920        """
921        if fitted_data.name is None:
922            name = self.nameForFittedData(self.logic.data.filename)
923            fitted_data.title = name
924            fitted_data.name = name
925        else:
926            name = fitted_data.name
927        # Make this a line if no other defined
928        if hasattr(fitted_data, 'symbol') and fitted_data.symbol is None:
929            fitted_data.symbol = 'Line'
930        # Notify the GUI manager so it can update the main model in DataExplorer
931        GuiUtils.updateModelItemWithPlot(self._index, QtCore.QVariant(fitted_data), name)
932
933    def createTheoryIndex(self, fitted_data):
934        """
935        Create a QStandardModelIndex containing model data
936        """
937        if fitted_data.name is None:
938            name = self.nameForFittedData(self.kernel_module.name)
939            fitted_data.title = name
940            fitted_data.name = name
941            fitted_data.filename = name
942        else:
943            name = fitted_data.name
944        # Notify the GUI manager so it can create the theory model in DataExplorer
945        new_item = GuiUtils.createModelItemWithPlot(QtCore.QVariant(fitted_data), name=name)
946        self.communicate.updateTheoryFromPerspectiveSignal.emit(new_item)
947
948    def methodCalculateForData(self):
949        '''return the method for data calculation'''
950        return Calc1D if isinstance(self.data, Data1D) else Calc2D
951
952    def methodCompleteForData(self):
953        '''return the method for result parsin on calc complete '''
954        return self.complete1D if isinstance(self.data, Data1D) else self.complete2D
955
956    def calculateQGridForModel(self):
957        """
958        Prepare the fitting data object, based on current ModelModel
959        """
960        if self.kernel_module is None:
961            return
962        # Awful API to a backend method.
963        method = self.methodCalculateForData()(data=self.data,
964                              model=self.kernel_module,
965                              page_id=0,
966                              qmin=self.q_range_min,
967                              qmax=self.q_range_max,
968                              smearer=None,
969                              state=None,
970                              weight=None,
971                              fid=None,
972                              toggle_mode_on=False,
973                              completefn=None,
974                              update_chisqr=True,
975                              exception_handler=self.calcException,
976                              source=None)
977
978        calc_thread = threads.deferToThread(method.compute)
979        calc_thread.addCallback(self.methodCompleteForData())
980
981    def complete1D(self, return_data):
982        """
983        Plot the current 1D data
984        """
985        fitted_data = self.logic.new1DPlot(return_data, self.tab_id)
986        self.calculateResiduals(fitted_data)
987        self.model_data = fitted_data
988
989    def complete2D(self, return_data):
990        """
991        Plot the current 2D data
992        """
993        fitted_data = self.logic.new2DPlot(return_data)
994        self.calculateResiduals(fitted_data)
995        self.model_data = fitted_data
996
997    def calculateResiduals(self, fitted_data):
998        """
999        Calculate and print Chi2 and display chart of residuals
1000        """
1001        # Create a new index for holding data
1002        fitted_data.symbol = "Line"
1003        self.createNewIndex(fitted_data)
1004        # Calculate difference between return_data and logic.data
1005        self.chi2 = FittingUtilities.calculateChi2(fitted_data, self.logic.data)
1006        # Update the control
1007        chi2_repr = "---" if self.chi2 is None else GuiUtils.formatNumber(self.chi2, high=True)
1008        self.lblChi2Value.setText(chi2_repr)
1009
1010        self.communicate.plotUpdateSignal.emit([fitted_data])
1011
1012        # Plot residuals if actual data
1013        if self.data_is_loaded:
1014            residuals_plot = FittingUtilities.plotResiduals(self.data, fitted_data)
1015            residuals_plot.id = "Residual " + residuals_plot.id
1016            self.createNewIndex(residuals_plot)
1017            self.communicate.plotUpdateSignal.emit([residuals_plot])
1018
1019    def calcException(self, etype, value, tb):
1020        """
1021        Something horrible happened in the deferred.
1022        """
1023        logging.error("".join(traceback.format_exception(etype, value, tb)))
1024
1025    def setTableProperties(self, table):
1026        """
1027        Setting table properties
1028        """
1029        # Table properties
1030        table.verticalHeader().setVisible(False)
1031        table.setAlternatingRowColors(True)
1032        table.setSizePolicy(QtGui.QSizePolicy.MinimumExpanding, QtGui.QSizePolicy.Expanding)
1033        table.setSelectionBehavior(QtGui.QAbstractItemView.SelectRows)
1034        table.resizeColumnsToContents()
1035
1036        # Header
1037        header = table.horizontalHeader()
1038        header.setResizeMode(QtGui.QHeaderView.ResizeToContents)
1039
1040        header.ResizeMode(QtGui.QHeaderView.Interactive)
1041        # Resize column 0 and 6 to content
1042        header.setResizeMode(0, QtGui.QHeaderView.ResizeToContents)
1043        header.setResizeMode(6, QtGui.QHeaderView.ResizeToContents)
1044
1045    def setPolyModel(self):
1046        """
1047        Set polydispersity values
1048        """
1049        if not self.model_parameters:
1050            return
1051        self._poly_model.clear()
1052        for row, param in enumerate(self.model_parameters.form_volume_parameters):
1053            # Counters should not be included
1054            if not param.polydisperse:
1055                continue
1056
1057            # Potential multishell params
1058            checked_list = ["Distribution of "+param.name, str(param.default),
1059                            str(param.limits[0]), str(param.limits[1]),
1060                            "35", "3", ""]
1061            FittingUtilities.addCheckedListToModel(self._poly_model, checked_list)
1062
1063            #TODO: Need to find cleaner way to input functions
1064            func = QtGui.QComboBox()
1065            func.addItems(['rectangle', 'array', 'lognormal', 'gaussian', 'schulz',])
1066            func_index = self.lstPoly.model().index(row, 6)
1067            self.lstPoly.setIndexWidget(func_index, func)
1068
1069        FittingUtilities.addPolyHeadersToModel(self._poly_model)
1070
1071    def setMagneticModel(self):
1072        """
1073        Set magnetism values on model
1074        """
1075        if not self.model_parameters:
1076            return
1077        self._magnet_model.clear()
1078        for param in self.model_parameters.call_parameters:
1079            if param.type != "magnetic":
1080                continue
1081            checked_list = [param.name,
1082                            str(param.default),
1083                            str(param.limits[0]),
1084                            str(param.limits[1]),
1085                            param.units]
1086            FittingUtilities.addCheckedListToModel(self._magnet_model, checked_list)
1087
1088        FittingUtilities.addHeadersToModel(self._magnet_model)
1089
1090    def addStructureFactor(self):
1091        """
1092        Add structure factors to the list of parameters
1093        """
1094        if self.kernel_module.is_form_factor:
1095            self.enableStructureCombo()
1096        else:
1097            self.disableStructureCombo()
1098
1099    def addExtraShells(self):
1100        """
1101        Add a combobox for multiple shell display
1102        """
1103        param_name, param_length = FittingUtilities.getMultiplicity(self.model_parameters)
1104
1105        if param_length == 0:
1106            return
1107
1108        # cell 1: variable name
1109        item1 = QtGui.QStandardItem(param_name)
1110
1111        func = QtGui.QComboBox()
1112        # Available range of shells displayed in the combobox
1113        func.addItems([str(i) for i in xrange(param_length+1)])
1114
1115        # Respond to index change
1116        func.currentIndexChanged.connect(self.modifyShellsInList)
1117
1118        # cell 2: combobox
1119        item2 = QtGui.QStandardItem()
1120        self._model_model.appendRow([item1, item2])
1121
1122        # Beautify the row:  span columns 2-4
1123        shell_row = self._model_model.rowCount()
1124        shell_index = self._model_model.index(shell_row-1, 1)
1125
1126        self.lstParams.setIndexWidget(shell_index, func)
1127        self._last_model_row = self._model_model.rowCount()
1128
1129        # Set the index to the state-kept value
1130        func.setCurrentIndex(self.current_shell_displayed
1131                             if self.current_shell_displayed < func.count() else 0)
1132
1133    def modifyShellsInList(self, index):
1134        """
1135        Add/remove additional multishell parameters
1136        """
1137        # Find row location of the combobox
1138        last_row = self._last_model_row
1139        remove_rows = self._model_model.rowCount() - last_row
1140
1141        if remove_rows > 1:
1142            self._model_model.removeRows(last_row, remove_rows)
1143
1144        FittingUtilities.addShellsToModel(self.model_parameters, self._model_model, index)
1145        self.current_shell_displayed = index
1146
Note: See TracBrowser for help on using the repository browser.