source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingWidget.py @ 672b8ab

Branches: ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc
Last change on this file since 672b8ab was 672b8ab, checked in by Piotr Rozyczko <rozyczko@…>, 7 years ago

Further fitpage implementation with tests SASVIEW-570

  • Property mode set to 100644
File size: 45.5 KB
Line 
1import sys
2import json
3import os
4import numpy as np
5from collections import defaultdict
6from itertools import izip
7
8import logging
9import traceback
10from twisted.internet import threads
11
12from PyQt4 import QtGui
13from PyQt4 import QtCore
14from PyQt4 import QtWebKit
15
16from sasmodels import generate
17from sasmodels import modelinfo
18from sasmodels.sasview_model import load_standard_models
19from sas.sascalc.fit.BumpsFitting import BumpsFit as Fit
20from sas.sasgui.perspectives.fitting.fit_thread import FitThread
21
22from sas.sasgui.guiframe.CategoryInstaller import CategoryInstaller
23from sas.sasgui.guiframe.dataFitting import Data1D
24from sas.sasgui.guiframe.dataFitting import Data2D
25import sas.qtgui.Utilities.GuiUtils as GuiUtils
26from sas.sasgui.perspectives.fitting.model_thread import Calc1D
27from sas.sasgui.perspectives.fitting.model_thread import Calc2D
28from sas.sasgui.perspectives.fitting.utils import get_weight
29
30from UI.FittingWidgetUI import Ui_FittingWidgetUI
31from sas.qtgui.Perspectives.Fitting.FittingLogic import FittingLogic
32from sas.qtgui.Perspectives.Fitting import FittingUtilities
33from SmearingWidget import SmearingWidget
34from OptionsWidget import OptionsWidget
35from FitPage import FitPage
36
# Indices of the optional tabs inside tabFitting
TAB_POLY = 3
TAB_MAGNETISM = 4
# Placeholder entry shown first in the category combobox
CATEGORY_DEFAULT = "Choose category..."
# Pseudo-category appended last to the category combobox
CATEGORY_STRUCTURE = "Structure Factor"
# Structure factor combobox entry meaning "no structure factor"
STRUCTURE_DEFAULT = "None"
42
43class FittingWidget(QtGui.QWidget, Ui_FittingWidgetUI):
44    """
45    Main widget for selecting form and structure factor models
46    """
47    def __init__(self, parent=None, data=None, id=1):
48
49        super(FittingWidget, self).__init__()
50
51        # Necessary globals
52        self.parent = parent
53        # SasModel is loaded
54        self.model_is_loaded = False
55        # Data[12]D passed and set
56        self.data_is_loaded = False
57        # Current SasModel in view
58        self.kernel_module = None
59        # Current SasModel view dimension
60        self.is2D = False
61        # Current SasModel is multishell
62        self.model_has_shells = False
63        # Utility variable to enable unselectable option in category combobox
64        self._previous_category_index = 0
65        # Utility variable for multishell display
66        self._last_model_row = 0
67        # Dictionary of {model name: model class} for the current category
68        self.models = {}
69        # Parameters to fit
70        self.parameters_to_fit = None
71        # Fit options
72        self.q_range_min = 0.005
73        self.q_range_max = 0.1
74        self.npts = 25
75        self.log_points = False
76        self.weighting = 0
77        self.chi2 = None
78
79        # Data for chosen model
80        self.model_data = None
81
82        # Which tab is this widget displayed in?
83        self.tab_id = id
84
85        # Which shell is being currently displayed?
86        self.current_shell_displayed = 0
87        self.has_error_column = False
88
89        # Main Data[12]D holder
90        self.logic = FittingLogic(data=data)
91
92        # Main GUI setup up
93        self.setupUi(self)
94        self.setWindowTitle("Fitting")
95        self.communicate = self.parent.communicate
96
97        # Options widget
98        layout = QtGui.QGridLayout()
99        self.options_widget = OptionsWidget(self, self.logic)
100        layout.addWidget(self.options_widget) 
101        self.tabOptions.setLayout(layout)
102
103        # Smearing widget
104        layout = QtGui.QGridLayout()
105        self.smearing_widget = SmearingWidget(self)
106        layout.addWidget(self.smearing_widget) 
107        self.tabResolution.setLayout(layout)
108
109        # Define bold font for use in various controls
110        self.boldFont=QtGui.QFont()
111        self.boldFont.setBold(True)
112
113        # Set data label
114        self.label.setFont(self.boldFont)
115        self.label.setText("No data loaded")
116        self.lblFilename.setText("")
117
118        # Set the main models
119        # We can't use a single model here, due to restrictions on flattening
120        # the model tree with subclassed QAbstractProxyModel...
121        self._model_model = QtGui.QStandardItemModel()
122        self._poly_model = QtGui.QStandardItemModel()
123        self._magnet_model = QtGui.QStandardItemModel()
124
125        # Param model displayed in param list
126        self.lstParams.setModel(self._model_model)
127        self.readCategoryInfo()
128        self.model_parameters = None
129        self.lstParams.setAlternatingRowColors(True)
130        stylesheet = """
131            QTreeView{
132                alternate-background-color: #f6fafb;
133                background: #e8f4fc;
134            }
135        """
136        self.lstParams.setStyleSheet(stylesheet)
137        self.lstParams.setContextMenuPolicy(QtCore.Qt.CustomContextMenu)
138        self.lstParams.customContextMenuRequested.connect(self.showModelDescription)
139
140        # Poly model displayed in poly list
141        self.lstPoly.setModel(self._poly_model)
142        self.setPolyModel()
143        self.setTableProperties(self.lstPoly)
144
145        # Magnetism model displayed in magnetism list
146        self.lstMagnetic.setModel(self._magnet_model)
147        self.setMagneticModel()
148        self.setTableProperties(self.lstMagnetic)
149
150        # Defaults for the structure factors
151        self.setDefaultStructureCombo()
152
153        # Make structure factor and model CBs disabled
154        self.disableModelCombo()
155        self.disableStructureCombo()
156
157        # Generate the category list for display
158        category_list = sorted(self.master_category_dict.keys())
159        self.cbCategory.addItem(CATEGORY_DEFAULT)
160        self.cbCategory.addItems(category_list)
161        self.cbCategory.addItem(CATEGORY_STRUCTURE)
162        self.cbCategory.setCurrentIndex(0)
163
164        # Connect signals to controls
165        self.initializeSignals()
166
167        # Initial control state
168        self.initializeControls()
169
170        # Display HTML content
171        self.helpView = QtWebKit.QWebView()
172
173        self._index = None
174        if data is not None:
175            self.data = data
176
177    def close(self):
178        """
179        Remember to kill off things on exit
180        """
181        self.helpView.close()
182        del self.helpView
183
184    @property
185    def data(self):
186        return self.logic.data
187
188    @data.setter
189    def data(self, value):
190        """ data setter """
191        assert isinstance(value, QtGui.QStandardItem)
192        # _index contains the QIndex with data
193        self._index = value
194
195        # Update logics with data items
196        self.logic.data = GuiUtils.dataFromItem(value)
197
198        # Overwrite data type descriptor
199        self.is2D = True if isinstance(self.logic.data, Data2D) else False
200
201        self.data_is_loaded = True
202
203        # Enable/disable UI components
204        self.setEnablementOnDataLoad()
205
206    def setEnablementOnDataLoad(self):
207        """
208        Enable/disable various UI elements based on data loaded
209        """
210        # Tag along functionality
211        self.label.setText("Data loaded from: ")
212        self.lblFilename.setText(self.logic.data.filename)
213        self.updateQRange()
214        # Switch off Data2D control
215        self.chk2DView.setEnabled(False)
216        self.chk2DView.setVisible(False)
217        self.chkMagnetism.setEnabled(True)
218        # Similarly on other tabs
219        self.options_widget.setEnablementOnDataLoad()
220
221        # Smearing tab
222        self.smearing_widget.updateSmearing(self.data)
223
224    def acceptsData(self):
225        """ Tells the caller this widget can accept new dataset """
226        return not self.data_is_loaded
227
228    def disableModelCombo(self):
229        """ Disable the combobox """
230        self.cbModel.setEnabled(False)
231        self.lblModel.setEnabled(False)
232
233    def enableModelCombo(self):
234        """ Enable the combobox """
235        self.cbModel.setEnabled(True)
236        self.lblModel.setEnabled(True)
237
238    def disableStructureCombo(self):
239        """ Disable the combobox """
240        self.cbStructureFactor.setEnabled(False)
241        self.lblStructure.setEnabled(False)
242
243    def enableStructureCombo(self):
244        """ Enable the combobox """
245        self.cbStructureFactor.setEnabled(True)
246        self.lblStructure.setEnabled(True)
247
248    def togglePoly(self, isChecked):
249        """ Enable/disable the polydispersity tab """
250        self.tabFitting.setTabEnabled(TAB_POLY, isChecked)
251
252    def toggleMagnetism(self, isChecked):
253        """ Enable/disable the magnetism tab """
254        self.tabFitting.setTabEnabled(TAB_MAGNETISM, isChecked)
255
256    def toggle2D(self, isChecked):
257        """ Enable/disable the controls dependent on 1D/2D data instance """
258        self.chkMagnetism.setEnabled(isChecked)
259        self.is2D = isChecked
260        # Reload the current model
261        if self.kernel_module:
262            self.onSelectModel()
263
264    def initializeControls(self):
265        """
266        Set initial control enablement
267        """
268        self.cmdFit.setEnabled(False)
269        self.cmdPlot.setEnabled(False)
270        self.options_widget.cmdComputePoints.setVisible(False) # probably redundant
271        self.chkPolydispersity.setEnabled(True)
272        self.chkPolydispersity.setCheckState(False)
273        self.chk2DView.setEnabled(True)
274        self.chk2DView.setCheckState(False)
275        self.chkMagnetism.setEnabled(False)
276        self.chkMagnetism.setCheckState(False)
277        # Tabs
278        self.tabFitting.setTabEnabled(TAB_POLY, False)
279        self.tabFitting.setTabEnabled(TAB_MAGNETISM, False)
280        self.lblChi2Value.setText("---")
281        # Smearing tab
282        self.smearing_widget.updateSmearing(self.data)
283        # Line edits in the option tab
284        self.updateQRange()
285
286    def initializeSignals(self):
287        """
288        Connect GUI element signals
289        """
290        # Comboboxes
291        self.cbStructureFactor.currentIndexChanged.connect(self.onSelectStructureFactor)
292        self.cbCategory.currentIndexChanged.connect(self.onSelectCategory)
293        self.cbModel.currentIndexChanged.connect(self.onSelectModel)
294        # Checkboxes
295        self.chk2DView.toggled.connect(self.toggle2D)
296        self.chkPolydispersity.toggled.connect(self.togglePoly)
297        self.chkMagnetism.toggled.connect(self.toggleMagnetism)
298        # Buttons
299        self.cmdFit.clicked.connect(self.onFit)
300        self.cmdPlot.clicked.connect(self.onPlot)
301        self.cmdHelp.clicked.connect(self.onHelp)
302
303        # Respond to change in parameters from the UI
304        self._model_model.itemChanged.connect(self.updateParamsFromModel)
305        self._poly_model.itemChanged.connect(self.onPolyModelChange)
306        # TODO after the poly_model prototype accepted
307        #self._magnet_model.itemChanged.connect(self.onMagneticModelChange)
308
309        # Signals from separate tabs asking for replot
310        self.options_widget.plot_signal.connect(self.onOptionsUpdate)
311
312    def showModelDescription(self, position):
313        """
314        Shows a window with model description, when right clicked in the treeview
315        """
316        msg = 'Model description:\n'
317        info = "Info"
318        if self.kernel_module is not None:
319            if str(self.kernel_module.description).rstrip().lstrip() == '':
320                msg += "Sorry, no information is available for this model."
321            else:
322                msg += self.kernel_module.description + '\n'
323        else:
324            msg += "You must select a model to get information on this"
325
326        menu = QtGui.QMenu()
327        label = QtGui.QLabel(msg)
328        action = QtGui.QWidgetAction(self)
329        action.setDefaultWidget(label)
330        menu.addAction(action)
331        menu.exec_(self.lstParams.viewport().mapToGlobal(position))
332
333    def onSelectModel(self):
334        """
335        Respond to select Model from list event
336        """
337        model = str(self.cbModel.currentText())
338
339        # Reset structure factor
340        self.cbStructureFactor.setCurrentIndex(0)
341
342        # Reset parameters to fit
343        self.parameters_to_fit = None
344        self.has_error_column = False
345
346        # Set enablement on calculate/plot
347        self.cmdPlot.setEnabled(True)
348
349        # SasModel -> QModel
350        self.SASModelToQModel(model)
351
352        if self.data_is_loaded:
353            self.cmdPlot.setText("Show Plot")
354            self.calculateQGridForModel()
355        else:
356            self.cmdPlot.setText("Calculate")
357            # Create default datasets if no data passed
358            self.createDefaultDataset()
359
360        state = self.currentState()
361
362    def onSelectStructureFactor(self):
363        """
364        Select Structure Factor from list
365        """
366        model = str(self.cbModel.currentText())
367        category = str(self.cbCategory.currentText())
368        structure = str(self.cbStructureFactor.currentText())
369        if category == CATEGORY_STRUCTURE:
370            model = None
371        self.SASModelToQModel(model, structure_factor=structure)
372
373    def onSelectCategory(self):
374        """
375        Select Category from list
376        """
377        category = str(self.cbCategory.currentText())
378        # Check if the user chose "Choose category entry"
379        if category == CATEGORY_DEFAULT:
380            # if the previous category was not the default, keep it.
381            # Otherwise, just return
382            if self._previous_category_index != 0:
383                # We need to block signals, or else state changes on perceived unchanged conditions
384                self.cbCategory.blockSignals(True)
385                self.cbCategory.setCurrentIndex(self._previous_category_index)
386                self.cbCategory.blockSignals(False)
387            return
388
389        if category == CATEGORY_STRUCTURE:
390            self.disableModelCombo()
391            self.enableStructureCombo()
392            self._model_model.clear()
393            return
394
395        # Safely clear and enable the model combo
396        self.cbModel.blockSignals(True)
397        self.cbModel.clear()
398        self.cbModel.blockSignals(False)
399        self.enableModelCombo()
400        self.disableStructureCombo()
401
402        self._previous_category_index = self.cbCategory.currentIndex()
403        # Retrieve the list of models
404        model_list = self.master_category_dict[category]
405        models = []
406        # Populate the models combobox
407        self.cbModel.addItems(sorted([model for (model, _) in model_list]))
408
409    def onPolyModelChange(self, item):
410        """
411        Callback method for updating the main model and sasmodel
412        parameters with the GUI values in the polydispersity view
413        """
414        model_column = item.column()
415        model_row = item.row()
416        name_index = self._poly_model.index(model_row, 0)
417        # Extract changed value. Assumes proper validation by QValidator/Delegate
418        # Checkbox in column 0
419        if model_column == 0:
420            value = item.checkState()
421        else:
422            try:
423                value = float(item.text())
424            except ValueError:
425                # Can't be converted properly, bring back the old value and exit
426                return
427
428        parameter_name = str(self._poly_model.data(name_index).toPyObject()) # "distribution of sld" etc.
429        if "Distribution of" in parameter_name:
430            parameter_name = parameter_name[16:]
431        property_name = str(self._poly_model.headerData(model_column, 1).toPyObject()) # Value, min, max, etc.
432        # print "%s(%s) => %d" % (parameter_name, property_name, value)
433
434        # Update the sasmodel
435        #self.kernel_module.params[parameter_name] = value
436
437        # Reload the main model - may not be required if no variable is shown in main view
438        #model = str(self.cbModel.currentText())
439        #self.SASModelToQModel(model)
440
441        pass # debug anchor
442
443    def onHelp(self):
444        """
445        Show the "Fitting" section of help
446        """
447        tree_location = self.parent.HELP_DIRECTORY_LOCATION +\
448            "/user/sasgui/perspectives/fitting/fitting_help.html"
449        self.helpView.load(QtCore.QUrl(tree_location))
450        self.helpView.show()
451
452    def onFit(self):
453        """
454        Perform fitting on the current data
455        """
456        fitter = Fit()
457
458        # Data going in
459        data = self.logic.data
460        model = self.kernel_module
461        qmin = self.q_range_min
462        qmax = self.q_range_max
463        params_to_fit = self.parameters_to_fit
464
465        # Potential weights added directly to data
466        self.addWeightingToData(data)
467
468        # Potential smearing added
469        # Remember that smearing_min/max can be None ->
470        # deal with it until Python gets discriminated unions
471        smearing, accuracy, smearing_min, smearing_max = self.smearing_widget.state()
472
473        # These should be updating somehow?
474        fit_id = 0
475        constraints = []
476        smearer = None
477        page_id = [210]
478        handler = None
479        batch_inputs = {}
480        batch_outputs = {}
481        list_page_id = [page_id]
482        #---------------------------------
483
484        # Parameterize the fitter
485        fitter.set_model(model, fit_id, params_to_fit, data=data,
486                         constraints=constraints)
487
488        fitter.set_data(data=data, id=fit_id, smearer=smearer, qmin=qmin,
489                        qmax=qmax)
490        fitter.select_problem_for_fit(id=fit_id, value=1)
491
492        fitter.fitter_id = page_id
493
494        # Create the fitting thread, based on the fitter
495        calc_fit = FitThread(handler=handler,
496                             fn=[fitter],
497                             batch_inputs=batch_inputs,
498                             batch_outputs=batch_outputs,
499                             page_id=list_page_id,
500                             updatefn=self.updateFit,
501                             completefn=None)
502
503        # start the trhrhread
504        calc_thread = threads.deferToThread(calc_fit.compute)
505        calc_thread.addCallback(self.fitComplete)
506        calc_thread.addErrback(self.fitFailed)
507
508        #disable the Fit button
509        self.cmdFit.setText('Calculating...')
510        self.communicate.statusBarUpdateSignal.emit('Fitting started...')
511        self.cmdFit.setEnabled(False)
512
513    def updateFit(self):
514        """
515        """
516        print "UPDATE FIT"
517        pass
518
519    def fitFailed(self, reason):
520        """
521        """
522        print "FIT FAILED: ", reason
523        pass
524
525    def fitComplete(self, result):
526        """
527        Receive and display fitting results
528        "result" is a tuple of actual result list and the fit time in seconds
529        """
530        #re-enable the Fit button
531        self.cmdFit.setText("Fit")
532        self.cmdFit.setEnabled(True)
533
534        assert result is not None
535
536        res_list = result[0]
537        res = res_list[0]
538        if res.fitness is None or \
539            not np.isfinite(res.fitness) or \
540            np.any(res.pvec == None) or \
541            not np.all(np.isfinite(res.pvec)):
542            msg = "Fitting did not converge!!!"
543            self.communicate.statusBarUpdateSignal.emit(msg)
544            logging.error(msg)
545            return
546
547        elapsed = result[1]
548        msg = "Fitting completed successfully in: %s s.\n" % GuiUtils.formatNumber(elapsed)
549        self.communicate.statusBarUpdateSignal.emit(msg)
550
551        self.chi2 = res.fitness
552        param_list = res.param_list
553        param_values = res.pvec
554        param_stderr = res.stderr
555        params_and_errors = zip(param_values, param_stderr)
556        param_dict = dict(izip(param_list, params_and_errors))
557
558        # Dictionary of fitted parameter: value, error
559        # e.g. param_dic = {"sld":(1.703, 0.0034), "length":(33.455, -0.0983)}
560        self.updateModelFromList(param_dict)
561
562        # update charts
563        self.onPlot()
564
565        # Read only value - we can get away by just printing it here
566        chi2_repr = GuiUtils.formatNumber(self.chi2, high=True)
567        self.lblChi2Value.setText(chi2_repr)
568
569    def iterateOverModel(self, func):
570        """
571        Take func and throw it inside the model row loop
572        """
573        #assert isinstance(func, function)
574        for row_i in xrange(self._model_model.rowCount()):
575            func(row_i)
576
577    def updateModelFromList(self, param_dict):
578        """
579        Update the model with new parameters, create the errors column
580        """
581        assert isinstance(param_dict, dict)
582        if not dict:
583            return
584
585        def updateFittedValues(row_i):
586            # Utility function for main model update
587            # internal so can use closure for param_dict
588            param_name = str(self._model_model.item(row_i, 0).text())
589            if param_name not in param_dict.keys():
590                return
591            # modify the param value
592            param_repr = GuiUtils.formatNumber(param_dict[param_name][0], high=True)
593            self._model_model.item(row_i, 1).setText(param_repr)
594            if self.has_error_column:
595                error_repr = GuiUtils.formatNumber(param_dict[param_name][1], high=True)
596                self._model_model.item(row_i, 2).setText(error_repr)
597
598        def createErrorColumn(row_i):
599            # Utility function for error column update
600            item = QtGui.QStandardItem()
601            for param_name in param_dict.keys():
602                if str(self._model_model.item(row_i, 0).text()) != param_name:
603                    continue
604                error_repr = GuiUtils.formatNumber(param_dict[param_name][1], high=True)
605                item.setText(error_repr)
606            error_column.append(item)
607
608        # block signals temporarily, so we don't end up
609        # updating charts with every single model change on the end of fitting
610        self._model_model.blockSignals(True)
611        self.iterateOverModel(updateFittedValues)
612        self._model_model.blockSignals(False)
613
614        if self.has_error_column:
615            return
616
617        error_column = []
618        self.iterateOverModel(createErrorColumn)
619
620        # switch off reponse to model change
621        self._model_model.blockSignals(True)
622        self._model_model.insertColumn(2, error_column)
623        self._model_model.blockSignals(False)
624        FittingUtilities.addErrorHeadersToModel(self._model_model)
625        # Adjust the table cells width.
626        # TODO: find a way to dynamically adjust column width while resized expanding
627        self.lstParams.resizeColumnToContents(0)
628        self.lstParams.resizeColumnToContents(4)
629        self.lstParams.resizeColumnToContents(5)
630        self.lstParams.setSizePolicy(QtGui.QSizePolicy.MinimumExpanding, QtGui.QSizePolicy.Expanding)
631
632        self.has_error_column = True
633
634    def onPlot(self):
635        """
636        Plot the current set of data
637        """
638        # Regardless of previous state, this should now be `plot show` functionality only
639        self.cmdPlot.setText("Show Plot")
640        self.showPlot()
641
642    def recalculatePlotData(self):
643        """
644        Generate a new dataset for model
645        """
646        if not self.data_is_loaded:
647            self.createDefaultDataset()
648        self.calculateQGridForModel()
649
650    def showPlot(self):
651        """
652        Show the current plot in MPL
653        """
654        # Show the chart if ready
655        data_to_show = self.data if self.data_is_loaded else self.model_data
656        if data_to_show is not None:
657            self.communicate.plotRequestedSignal.emit([data_to_show])
658
659    def onOptionsUpdate(self):
660        """
661        Update local option values and replot
662        """
663        self.q_range_min, self.q_range_max, self.npts, self.log_points, self.weighting = \
664            self.options_widget.state()
665        # set Q range labels on the main tab
666        self.lblMinRangeDef.setText(str(self.q_range_min))
667        self.lblMaxRangeDef.setText(str(self.q_range_max))
668        self.recalculatePlotData()
669
670    def setDefaultStructureCombo(self):
671        """
672        Fill in the structure factors combo box with defaults
673        """
674        structure_factor_list = self.master_category_dict.pop(CATEGORY_STRUCTURE)
675        factors = [factor[0] for factor in structure_factor_list]
676        factors.insert(0, STRUCTURE_DEFAULT)
677        self.cbStructureFactor.clear()
678        self.cbStructureFactor.addItems(sorted(factors))
679
680    def createDefaultDataset(self):
681        """
682        Generate default Dataset 1D/2D for the given model
683        """
684        # Create default datasets if no data passed
685        if self.is2D:
686            qmax = self.q_range_max/np.sqrt(2)
687            qstep = self.npts
688            self.logic.createDefault2dData(qmax, qstep, self.tab_id)
689            return
690        elif self.log_points:
691            qmin = -10.0 if self.q_range_min < 1.e-10 else np.log10(self.q_range_min)
692            qmax =  10.0 if self.q_range_max > 1.e10 else np.log10(self.q_range_max)
693            interval = np.logspace(start=qmin, stop=qmax, num=self.npts, endpoint=True, base=10.0)
694        else:
695            interval = np.linspace(start=self.q_range_min, stop=self.q_range_max,
696                    num=self.npts, endpoint=True)
697        self.logic.createDefault1dData(interval, self.tab_id)
698
699    def readCategoryInfo(self):
700        """
701        Reads the categories in from file
702        """
703        self.master_category_dict = defaultdict(list)
704        self.by_model_dict = defaultdict(list)
705        self.model_enabled_dict = defaultdict(bool)
706
707        categorization_file = CategoryInstaller.get_user_file()
708        if not os.path.isfile(categorization_file):
709            categorization_file = CategoryInstaller.get_default_file()
710        with open(categorization_file, 'rb') as cat_file:
711            self.master_category_dict = json.load(cat_file)
712            self.regenerateModelDict()
713
714        # Load the model dict
715        models = load_standard_models()
716        for model in models:
717            self.models[model.name] = model
718
719    def regenerateModelDict(self):
720        """
721        Regenerates self.by_model_dict which has each model name as the
722        key and the list of categories belonging to that model
723        along with the enabled mapping
724        """
725        self.by_model_dict = defaultdict(list)
726        for category in self.master_category_dict:
727            for (model, enabled) in self.master_category_dict[category]:
728                self.by_model_dict[model].append(category)
729                self.model_enabled_dict[model] = enabled
730
731    def addBackgroundToModel(self, model):
732        """
733        Adds background parameter with default values to the model
734        """
735        assert isinstance(model, QtGui.QStandardItemModel)
736        checked_list = ['background', '0.001', '-inf', 'inf', '1/cm']
737        FittingUtilities.addCheckedListToModel(model, checked_list)
738        last_row = model.rowCount()-1
739        model.item(last_row, 0).setEditable(False)
740        model.item(last_row, 4).setEditable(False)
741
742    def addScaleToModel(self, model):
743        """
744        Adds scale parameter with default values to the model
745        """
746        assert isinstance(model, QtGui.QStandardItemModel)
747        checked_list = ['scale', '1.0', '0.0', 'inf', '']
748        FittingUtilities.addCheckedListToModel(model, checked_list)
749        last_row = model.rowCount()-1
750        model.item(last_row, 0).setEditable(False)
751        model.item(last_row, 4).setEditable(False)
752
753    def addWeightingToData(self, data):
754        """
755        Adds weighting contribution to fitting data
756        #"""
757        # Send original data for weighting
758        weight = get_weight(data=data, is2d=self.is2D, flag=self.weighting)
759        update_module = data.err_data if self.is2D else data.dy
760        update_module = weight
761
762    def updateQRange(self):
763        """
764        Updates Q Range display
765        """
766        if self.data_is_loaded:
767            self.q_range_min, self.q_range_max, self.npts = self.logic.computeDataRange()
768        # set Q range labels on the main tab
769        self.lblMinRangeDef.setText(str(self.q_range_min))
770        self.lblMaxRangeDef.setText(str(self.q_range_max))
771        # set Q range labels on the options tab
772        self.options_widget.updateQRange(self.q_range_min, self.q_range_max, self.npts)
773
    def SASModelToQModel(self, model_name, structure_factor=None):
        """
        Setting model parameters into table based on selected category

        Rebuilds self._model_model for the sasmodel ``model_name``:
        - clears the model, loads the kernel module's parameter table into
          self.model_parameters and instantiates self.kernel_module from
          self.models[model_name];
        - adds the scale and background rows, then the model parameter rows
          and the column headers;
        - if ``structure_factor`` is given (and is not the string "None"),
          appends its parameter rows; otherwise calls addStructureFactor()
          to enable/disable the structure factor combobox;
        - refreshes the multishell row, polydispersity and magnetism tables,
          resizes the parameter view, sets self.model_is_loaded and updates
          the Q range.

        :param model_name: name of the sasmodel (key into self.models and
            argument to generate.load_kernel_module)
        :param structure_factor: optional structure factor model name
        """
        # TODO - modify for structure factor-only choice

        # Create/overwrite model items
        self._model_model.clear()

        kernel_module = generate.load_kernel_module(model_name)
        self.model_parameters = modelinfo.make_parameter_table(getattr(kernel_module, 'parameters', []))

        # Instantiate the current sasmodel
        self.kernel_module = self.models[model_name]()

        # Explicitly add scale and background with default values
        self.addScaleToModel(self._model_model)
        self.addBackgroundToModel(self._model_model)

        # Update the QModel
        new_rows = FittingUtilities.addParametersToModel(self.model_parameters, self.is2D)
        for row in new_rows:
            self._model_model.appendRow(row)
        # Update the counter used for multishell display
        self._last_model_row = self._model_model.rowCount()

        FittingUtilities.addHeadersToModel(self._model_model)

        # Add structure factor
        if structure_factor is not None and structure_factor != "None":
            structure_module = generate.load_kernel_module(structure_factor)
            structure_parameters = modelinfo.make_parameter_table(getattr(structure_module, 'parameters', []))
            new_rows = FittingUtilities.addSimpleParametersToModel(structure_parameters, self.is2D)
            for row in new_rows:
                self._model_model.appendRow(row)
            # Update the counter used for multishell display
            self._last_model_row = self._model_model.rowCount()
        else:
            # No explicit structure factor: toggle the combobox depending on
            # whether the current kernel module is a form factor
            self.addStructureFactor()

        # Multishell models need additional treatment
        self.addExtraShells()

        # Add polydispersity to the model
        self.setPolyModel()
        # Add magnetic parameters to the model
        self.setMagneticModel()

        # Adjust the table cells width
        self.lstParams.resizeColumnToContents(0)
        self.lstParams.setSizePolicy(QtGui.QSizePolicy.MinimumExpanding, QtGui.QSizePolicy.Expanding)

        # Now we claim the model has been loaded
        self.model_is_loaded = True

        # Update Q Ranges
        self.updateQRange()
821
822        # Adjust the table cells width
823        self.lstParams.resizeColumnToContents(0)
824        self.lstParams.setSizePolicy(QtGui.QSizePolicy.MinimumExpanding, QtGui.QSizePolicy.Expanding)
825
826        # Now we claim the model has been loaded
827        self.model_is_loaded = True
828
829        # Update Q Ranges
830        self.updateQRange()
831
832    def updateParamsFromModel(self, item):
833        """
834        Callback method for updating the sasmodel parameters with the GUI values
835        """
836        model_column = item.column()
837
838        if model_column == 0:
839            self.checkboxSelected(item)
840            self.cmdFit.setEnabled(self.parameters_to_fit != [] and self.logic.data_is_loaded)
841            return
842
843        model_row = item.row()
844        name_index = self._model_model.index(model_row, 0)
845
846        # Extract changed value. Assumes proper validation by QValidator/Delegate
847        # TODO: disable model update for uneditable cells/columns
848        try:
849            value = float(item.text())
850        except ValueError:
851            # Unparsable field
852            return
853        parameter_name = str(self._model_model.data(name_index).toPyObject()) # sld, background etc.
854        property_name = str(self._model_model.headerData(1, model_column).toPyObject()) # Value, min, max, etc.
855
856        self.kernel_module.params[parameter_name] = value
857
858        # min/max to be changed in self.kernel_module.details[parameter_name] = ['Ang', 0.0, inf]
859        # magnetic params in self.kernel_module.details['M0:parameter_name'] = value
860        # multishell params in self.kernel_module.details[??] = value
861
862        # Force the chart update when actual parameters changed
863        if model_column == 1:
864            self.recalculatePlotData()
865
866    def checkboxSelected(self, item):
867        # Assure we're dealing with checkboxes
868        if not item.isCheckable():
869            return
870        status = item.checkState()
871
872        def isChecked(row):
873            return self._model_model.item(row, 0).checkState() == QtCore.Qt.Checked
874
875        def isCheckable(row):
876            return self._model_model.item(row, 0).isCheckable()
877
878        # If multiple rows selected - toggle all of them, filtering uncheckable
879        rows = [s.row() for s in self.lstParams.selectionModel().selectedRows() if isCheckable(s.row())]
880
881        # Switch off signaling from the model to avoid recursion
882        self._model_model.blockSignals(True)
883        # Convert to proper indices and set requested enablement
884        items = [self._model_model.item(row, 0).setCheckState(status) for row in rows]
885        self._model_model.blockSignals(False)
886
887        # update the list of parameters to fit
888        self.parameters_to_fit = [str(self._model_model.item(row_index, 0).text())
889                                  for row_index in xrange(self._model_model.rowCount())
890                                  if isChecked(row_index)]
891
892    def nameForFittedData(self, name):
893        """
894        Generate name for the current fit
895        """
896        if self.is2D:
897            name += "2d"
898        name = "M%i [%s]" % (self.tab_id, name)
899        return name
900
901    def createNewIndex(self, fitted_data):
902        """
903        Create a model or theory index with passed Data1D/Data2D
904        """
905        if self.data_is_loaded:
906            if not fitted_data.name:
907                name = self.nameForFittedData(self.data.filename)
908                fitted_data.title = name
909                fitted_data.name = name
910                fitted_data.filename = name
911                fitted_data.symbol = "Line"
912            self.updateModelIndex(fitted_data)
913        else:
914            name = self.nameForFittedData(self.kernel_module.name)
915            fitted_data.title = name
916            fitted_data.name = name
917            fitted_data.filename = name
918            fitted_data.symbol = "Line"
919            self.createTheoryIndex(fitted_data)
920
921    def updateModelIndex(self, fitted_data):
922        """
923        Update a QStandardModelIndex containing model data
924        """
925        if fitted_data.name is None:
926            name = self.nameForFittedData(self.logic.data.filename)
927            fitted_data.title = name
928            fitted_data.name = name
929        else:
930            name = fitted_data.name
931        # Make this a line if no other defined
932        if hasattr(fitted_data, 'symbol') and fitted_data.symbol is None:
933            fitted_data.symbol = 'Line'
934        # Notify the GUI manager so it can update the main model in DataExplorer
935        GuiUtils.updateModelItemWithPlot(self._index, QtCore.QVariant(fitted_data), name)
936
937    def createTheoryIndex(self, fitted_data):
938        """
939        Create a QStandardModelIndex containing model data
940        """
941        if fitted_data.name is None:
942            name = self.nameForFittedData(self.kernel_module.name)
943            fitted_data.title = name
944            fitted_data.name = name
945            fitted_data.filename = name
946        else:
947            name = fitted_data.name
948        # Notify the GUI manager so it can create the theory model in DataExplorer
949        new_item = GuiUtils.createModelItemWithPlot(QtCore.QVariant(fitted_data), name=name)
950        self.communicate.updateTheoryFromPerspectiveSignal.emit(new_item)
951
952    def methodCalculateForData(self):
953        '''return the method for data calculation'''
954        return Calc1D if isinstance(self.data, Data1D) else Calc2D
955
956    def methodCompleteForData(self):
957        '''return the method for result parsin on calc complete '''
958        return self.complete1D if isinstance(self.data, Data1D) else self.complete2D
959
960    def calculateQGridForModel(self):
961        """
962        Prepare the fitting data object, based on current ModelModel
963        """
964        if self.kernel_module is None:
965            return
966        # Awful API to a backend method.
967        method = self.methodCalculateForData()(data=self.data,
968                              model=self.kernel_module,
969                              page_id=0,
970                              qmin=self.q_range_min,
971                              qmax=self.q_range_max,
972                              smearer=None,
973                              state=None,
974                              weight=None,
975                              fid=None,
976                              toggle_mode_on=False,
977                              completefn=None,
978                              update_chisqr=True,
979                              exception_handler=self.calcException,
980                              source=None)
981
982        calc_thread = threads.deferToThread(method.compute)
983        calc_thread.addCallback(self.methodCompleteForData())
984
985    def complete1D(self, return_data):
986        """
987        Plot the current 1D data
988        """
989        fitted_data = self.logic.new1DPlot(return_data, self.tab_id)
990        self.calculateResiduals(fitted_data)
991        self.model_data = fitted_data
992
993    def complete2D(self, return_data):
994        """
995        Plot the current 2D data
996        """
997        fitted_data = self.logic.new2DPlot(return_data)
998        self.calculateResiduals(fitted_data)
999        self.model_data = fitted_data
1000
1001    def calculateResiduals(self, fitted_data):
1002        """
1003        Calculate and print Chi2 and display chart of residuals
1004        """
1005        # Create a new index for holding data
1006        fitted_data.symbol = "Line"
1007        self.createNewIndex(fitted_data)
1008        # Calculate difference between return_data and logic.data
1009        self.chi2 = FittingUtilities.calculateChi2(fitted_data, self.logic.data)
1010        # Update the control
1011        chi2_repr = "---" if self.chi2 is None else GuiUtils.formatNumber(self.chi2, high=True)
1012        self.lblChi2Value.setText(chi2_repr)
1013
1014        self.communicate.plotUpdateSignal.emit([fitted_data])
1015
1016        # Plot residuals if actual data
1017        if self.data_is_loaded:
1018            residuals_plot = FittingUtilities.plotResiduals(self.data, fitted_data)
1019            residuals_plot.id = "Residual " + residuals_plot.id
1020            self.createNewIndex(residuals_plot)
1021            self.communicate.plotUpdateSignal.emit([residuals_plot])
1022
1023    def calcException(self, etype, value, tb):
1024        """
1025        Something horrible happened in the deferred.
1026        """
1027        logging.error("".join(traceback.format_exception(etype, value, tb)))
1028
1029    def setTableProperties(self, table):
1030        """
1031        Setting table properties
1032        """
1033        # Table properties
1034        table.verticalHeader().setVisible(False)
1035        table.setAlternatingRowColors(True)
1036        table.setSizePolicy(QtGui.QSizePolicy.MinimumExpanding, QtGui.QSizePolicy.Expanding)
1037        table.setSelectionBehavior(QtGui.QAbstractItemView.SelectRows)
1038        table.resizeColumnsToContents()
1039
1040        # Header
1041        header = table.horizontalHeader()
1042        header.setResizeMode(QtGui.QHeaderView.ResizeToContents)
1043
1044        header.ResizeMode(QtGui.QHeaderView.Interactive)
1045        # Resize column 0 and 6 to content
1046        header.setResizeMode(0, QtGui.QHeaderView.ResizeToContents)
1047        header.setResizeMode(6, QtGui.QHeaderView.ResizeToContents)
1048
1049    def setPolyModel(self):
1050        """
1051        Set polydispersity values
1052        """
1053        if not self.model_parameters:
1054            return
1055        self._poly_model.clear()
1056        for row, param in enumerate(self.model_parameters.form_volume_parameters):
1057            # Counters should not be included
1058            if not param.polydisperse:
1059                continue
1060
1061            # Potential multishell params
1062            checked_list = ["Distribution of "+param.name, str(param.default),
1063                            str(param.limits[0]), str(param.limits[1]),
1064                            "35", "3", ""]
1065            FittingUtilities.addCheckedListToModel(self._poly_model, checked_list)
1066
1067            #TODO: Need to find cleaner way to input functions
1068            func = QtGui.QComboBox()
1069            func.addItems(['rectangle', 'array', 'lognormal', 'gaussian', 'schulz',])
1070            func_index = self.lstPoly.model().index(row, 6)
1071            self.lstPoly.setIndexWidget(func_index, func)
1072
1073        FittingUtilities.addPolyHeadersToModel(self._poly_model)
1074
1075    def setMagneticModel(self):
1076        """
1077        Set magnetism values on model
1078        """
1079        if not self.model_parameters:
1080            return
1081        self._magnet_model.clear()
1082        for param in self.model_parameters.call_parameters:
1083            if param.type != "magnetic":
1084                continue
1085            checked_list = [param.name,
1086                            str(param.default),
1087                            str(param.limits[0]),
1088                            str(param.limits[1]),
1089                            param.units]
1090            FittingUtilities.addCheckedListToModel(self._magnet_model, checked_list)
1091
1092        FittingUtilities.addHeadersToModel(self._magnet_model)
1093
1094    def addStructureFactor(self):
1095        """
1096        Add structure factors to the list of parameters
1097        """
1098        if self.kernel_module.is_form_factor:
1099            self.enableStructureCombo()
1100        else:
1101            self.disableStructureCombo()
1102
1103    def addExtraShells(self):
1104        """
1105        Add a combobox for multiple shell display
1106        """
1107        param_name, param_length = FittingUtilities.getMultiplicity(self.model_parameters)
1108
1109        if param_length == 0:
1110            return
1111
1112        # cell 1: variable name
1113        item1 = QtGui.QStandardItem(param_name)
1114
1115        func = QtGui.QComboBox()
1116        # Available range of shells displayed in the combobox
1117        func.addItems([str(i) for i in xrange(param_length+1)])
1118
1119        # Respond to index change
1120        func.currentIndexChanged.connect(self.modifyShellsInList)
1121
1122        # cell 2: combobox
1123        item2 = QtGui.QStandardItem()
1124        self._model_model.appendRow([item1, item2])
1125
1126        # Beautify the row:  span columns 2-4
1127        shell_row = self._model_model.rowCount()
1128        shell_index = self._model_model.index(shell_row-1, 1)
1129
1130        self.lstParams.setIndexWidget(shell_index, func)
1131        self._last_model_row = self._model_model.rowCount()
1132
1133        # Set the index to the state-kept value
1134        func.setCurrentIndex(self.current_shell_displayed
1135                             if self.current_shell_displayed < func.count() else 0)
1136
1137    def modifyShellsInList(self, index):
1138        """
1139        Add/remove additional multishell parameters
1140        """
1141        # Find row location of the combobox
1142        last_row = self._last_model_row
1143        remove_rows = self._model_model.rowCount() - last_row
1144
1145        if remove_rows > 1:
1146            self._model_model.removeRows(last_row, remove_rows)
1147
1148        FittingUtilities.addShellsToModel(self.model_parameters, self._model_model, index)
1149        self.current_shell_displayed = index
1150
1151    def readFitPage(self, fp):
1152        """
1153        Read in state from a fitpage object and update GUI
1154        """
1155        assert isinstance(fp, FitPage)
1156        # Main tab info
1157        self.logic.data.filename = fp.filename
1158        self.data_is_loaded = fp.data_is_loaded
1159        self.chkPolydispersity.setCheckState(fp.is_polydisperse)
1160        self.chkMagnetism.setCheckState(fp.is_magnetic)
1161        self.chk2DView.setCheckState(fp.is2D)
1162
1163        # Update the comboboxes
1164        self.cbCategory.setCurrentIndex(self.cbCategory.findText(fp.current_category))
1165        self.cbModel.setCurrentIndex(self.cbModel.findText(fp.current_model))
1166        if fp.current_factor:
1167            self.cbStructureFactor.setCurrentIndex(self.cbStructureFactor.findText(fp.current_factor))
1168
1169        self.chi2 = fp.chi2
1170
1171        # Options tab
1172        self.q_range_min = fp.fit_options[fp.MIN_RANGE]
1173        self.q_range_max = fp.fit_options[fp.MAX_RANGE]
1174        self.npts = fp.fit_options[fp.NPTS]
1175        #fp.fit_options[fp.NPTS_FIT] = self.npts_fit
1176        self.log_points = fp.fit_options[fp.LOG_POINTS]
1177        self.weighting = fp.fit_options[fp.WEIGHTING]
1178
1179        # Models
1180        #self._model_model = fp.model_model
1181        #self._poly_model = fp.poly_model
1182        #self._magnet_model = fp.magnetism_model
1183
1184        # Resolution tab
1185        smearing = fp.smearing_options[fp.SMEARING_OPTION]
1186        accuracy = fp.smearing_options[fp.SMEARING_ACCURACY]
1187        smearing_min = fp.smearing_options[fp.SMEARING_MIN]
1188        smearing_max = fp.smearing_options[fp.SMEARING_MAX]
1189        self.smearing_widget.setState(smearing, accuracy, smearing_min, smearing_max)
1190
1191        # TODO: add polidyspersity and magnetism
1192
1193    def saveToFitPage(self, fp):
1194        """
1195        Write current state to the given fitpage
1196        """
1197        assert isinstance(fp, FitPage)
1198
1199        # Main tab info
1200        fp.filename = self.logic.data.filename
1201        fp.data_is_loaded = self.data_is_loaded
1202        fp.is_polydisperse = self.chkPolydispersity.isChecked()
1203        fp.is_magnetic = self.chkMagnetism.isChecked()
1204        fp.is2D = self.chk2DView.isChecked()
1205        fp.data = self.data
1206
1207        # Use current models - they contain all the required parameters
1208        fp.model_model = self._model_model
1209        fp.poly_model = self._poly_model
1210        fp.magnetism_model = self._magnet_model
1211
1212        if self.cbCategory.currentIndex() != 0:
1213            fp.current_category = str(self.cbCategory.currentText())
1214            fp.current_model = str(self.cbModel.currentText())
1215
1216        if self.cbStructureFactor.isEnabled() and self.cbStructureFactor.currentIndex() != 0:
1217            fp.current_factor = str(self.cbStructureFactor.currentText())
1218        else:
1219            fp.current_factor = ''
1220
1221        fp.chi2 = self.chi2
1222        fp.parameters_to_fit = self.parameters_to_fit
1223
1224        # Options tab
1225        fp.fit_options[fp.MIN_RANGE] = self.q_range_min
1226        fp.fit_options[fp.MAX_RANGE] = self.q_range_max
1227        fp.fit_options[fp.NPTS] = self.npts
1228        #fp.fit_options[fp.NPTS_FIT] = self.npts_fit
1229        fp.fit_options[fp.LOG_POINTS] = self.log_points
1230        fp.fit_options[fp.WEIGHTING] = self.weighting
1231
1232        # Resolution tab
1233        smearing, accuracy, smearing_min, smearing_max = self.smearing_widget.state()
1234        fp.smearing_options[fp.SMEARING_OPTION] = smearing
1235        fp.smearing_options[fp.SMEARING_ACCURACY] = accuracy
1236        fp.smearing_options[fp.SMEARING_MIN] = smearing_min
1237        fp.smearing_options[fp.SMEARING_MAX] = smearing_max
1238
1239        # TODO: add polidyspersity and magnetism
1240
1241    def currentState(self):
1242        """
1243        Return fit page with current state
1244        """
1245        new_page = FitPage()
1246        self.saveToFitPage(new_page)
1247
1248        return new_page
1249
1250    def pushFitPage(self, new_page):
1251        """
1252        Add a new fit page object with current state
1253        """
1254        #page_stack.append(new_page)
1255        pass
1256
1257    def popFitPage(self):
1258        """
1259        Remove top fit page from stack
1260        """
1261        #if page_stack:
1262        #    page_stack.pop()
1263        pass
1264
Note: See TracBrowser for help on using the repository browser.