source: sasview/src/sas/qtgui/Perspectives/Fitting/FittingWidget.py @ 61a92d4

Branches: ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc
Last change on this file since 61a92d4 was 61a92d4, checked in by Piotr Rozyczko <rozyczko@…>, 7 years ago

Minor UI resizing.
Refactored ObjectFactory.

  • Property mode set to 100644
File size: 38.1 KB
Line 
1import sys
2import json
3import os
4import numpy as np
5from collections import defaultdict
6from itertools import izip
7
8import logging
9import traceback
10from twisted.internet import threads
11
12from PyQt4 import QtGui
13from PyQt4 import QtCore
14
15from sasmodels import generate
16from sasmodels import modelinfo
17from sasmodels.sasview_model import load_standard_models
18from sas.sascalc.fit.BumpsFitting import BumpsFit as Fit
19from sas.sasgui.perspectives.fitting.fit_thread import FitThread
20
21from sas.sasgui.guiframe.CategoryInstaller import CategoryInstaller
22from sas.sasgui.guiframe.dataFitting import Data1D
23from sas.sasgui.guiframe.dataFitting import Data2D
24import sas.qtgui.Utilities.GuiUtils as GuiUtils
25from sas.sasgui.perspectives.fitting.model_thread import Calc1D
26from sas.sasgui.perspectives.fitting.model_thread import Calc2D
27from sas.sasgui.perspectives.fitting.utils import get_weight
28
29from UI.FittingWidgetUI import Ui_FittingWidgetUI
30from sas.qtgui.Perspectives.Fitting.FittingLogic import FittingLogic
31from sas.qtgui.Perspectives.Fitting import FittingUtilities
32from SmearingWidget import SmearingWidget
33from OptionsWidget import OptionsWidget
34
# Index of the polydispersity tab inside tabFitting
TAB_POLY = 3
# Index of the magnetism tab inside tabFitting
TAB_MAGNETISM = 4
# Placeholder entry shown first in the category combobox
CATEGORY_DEFAULT = "Choose category..."
# Pseudo-category listing only structure factors (appended last to the combobox)
CATEGORY_STRUCTURE = "Structure Factor"
# Structure factor combobox entry meaning "no structure factor"
STRUCTURE_DEFAULT = "None"
39STRUCTURE_DEFAULT = "None"
40
41class FittingWidget(QtGui.QWidget, Ui_FittingWidgetUI):
42    """
43    Main widget for selecting form and structure factor models
44    """
    def __init__(self, parent=None, data=None, id=1):
        """
        Build the fitting page: UI, parameter/polydispersity/magnetism item
        models, category and structure factor comboboxes, signal connections.

        :param parent: owner of this widget; its ``communicate`` attribute is
            read unconditionally, so it must be supplied despite the default.
        :param data: optional QStandardItem holding the dataset. It is handed
            to FittingLogic as-is and, when not None, also passed through the
            ``data`` property setter at the end of construction.
        :param id: number of the tab hosting this widget, stored as ``tab_id``.
        """

        # NOTE(review): the Qt parent is not forwarded to the base class
        super(FittingWidget, self).__init__()

        # Necessary globals
        self.parent = parent
        # SasModel is loaded
        self.model_is_loaded = False
        # Data[12]D passed and set
        self.data_is_loaded = False
        # Current SasModel in view
        self.kernel_module = None
        # Current SasModel view dimension
        self.is2D = False
        # Current SasModel is multishell
        self.model_has_shells = False
        # Utility variable to enable unselectable option in category combobox
        self._previous_category_index = 0
        # Utility variable for multishell display
        self._last_model_row = 0
        # Dictionary of {model name: model class}; filled by readCategoryInfo
        self.models = {}
        # Parameters to fit (None until a fit checkbox is toggled)
        self.parameters_to_fit = None
        # Fit options
        self.q_range_min = 0.005
        self.q_range_max = 0.1
        self.npts = 25
        self.log_points = False
        self.weighting = 0

        # Which tab is this widget displayed in?
        self.tab_id = id

        # Which shell is being currently displayed?
        self.current_shell_displayed = 0
        # True once updateModelFromList has inserted the error column
        self.has_error_column = False

        # Main Data[12]D holder
        self.logic = FittingLogic(data=data)

        # Main GUI setup up - must precede any access to the designer widgets
        self.setupUi(self)
        self.setWindowTitle("Fitting")
        self.communicate = self.parent.communicate

        # Options widget
        layout = QtGui.QGridLayout()
        self.options_widget = OptionsWidget(self, self.logic)
        layout.addWidget(self.options_widget)
        self.tabOptions.setLayout(layout)

        # Smearing widget
        layout = QtGui.QGridLayout()
        self.smearing_widget = SmearingWidget(self)
        layout.addWidget(self.smearing_widget)
        self.tabResolution.setLayout(layout)

        # Define bold font for use in various controls
        self.boldFont=QtGui.QFont()
        self.boldFont.setBold(True)

        # Set data label
        self.label.setFont(self.boldFont)
        self.label.setText("No data loaded")
        self.lblFilename.setText("")

        # Set the main models
        # We can't use a single model here, due to restrictions on flattening
        # the model tree with subclassed QAbstractProxyModel...
        self._model_model = QtGui.QStandardItemModel()
        self._poly_model = QtGui.QStandardItemModel()
        self._magnet_model = QtGui.QStandardItemModel()

        # Param model displayed in param list
        self.lstParams.setModel(self._model_model)
        # Loads master_category_dict; must run before setDefaultStructureCombo,
        # which pops the structure factor category out of it
        self.readCategoryInfo()
        self.model_parameters = None
        self.lstParams.setAlternatingRowColors(True)
        stylesheet = """
            QTreeView{
                alternate-background-color: #f6fafb;
                background: #e8f4fc;
            }
        """
        self.lstParams.setStyleSheet(stylesheet)

        # Poly model displayed in poly list
        self.lstPoly.setModel(self._poly_model)
        self.setPolyModel()
        self.setTableProperties(self.lstPoly)

        # Magnetism model displayed in magnetism list
        self.lstMagnetic.setModel(self._magnet_model)
        self.setMagneticModel()
        self.setTableProperties(self.lstMagnetic)

        # Defaults for the structure factors
        self.setDefaultStructureCombo()

        # Make structure factor and model CBs disabled
        self.disableModelCombo()
        self.disableStructureCombo()

        # Generate the category list for display. The structure factor
        # category was already removed above, so it is appended exactly once.
        category_list = sorted(self.master_category_dict.keys())
        self.cbCategory.addItem(CATEGORY_DEFAULT)
        self.cbCategory.addItems(category_list)
        self.cbCategory.addItem(CATEGORY_STRUCTURE)
        self.cbCategory.setCurrentIndex(0)

        # Connect signals to controls
        self.initializeSignals()

        # Initial control state
        self.initializeControls()

        # QStandardItem holding the data; assigned by the data setter
        self._index = None
        if data is not None:
            self.data = data
165
166    @property
167    def data(self):
168        return self.logic.data
169
170    @data.setter
171    def data(self, value):
172        """ data setter """
173        assert isinstance(value, QtGui.QStandardItem)
174        # _index contains the QIndex with data
175        self._index = value
176
177        # Update logics with data items
178        self.logic.data = GuiUtils.dataFromItem(value)
179
180        # Overwrite data type descriptor
181        self.is2D = True if isinstance(self.logic.data, Data2D) else False
182
183        self.data_is_loaded = True
184
185        # Enable/disable UI components
186        self.setEnablementOnDataLoad()
187
188    def setEnablementOnDataLoad(self):
189        """
190        Enable/disable various UI elements based on data loaded
191        """
192        # Tag along functionality
193        self.label.setText("Data loaded from: ")
194        self.lblFilename.setText(self.logic.data.filename)
195        self.updateQRange()
196        self.cmdFit.setEnabled(True)
197        # Switch off Data2D control
198        self.chk2DView.setEnabled(False)
199        self.chk2DView.setVisible(False)
200        self.chkMagnetism.setEnabled(True)
201        # Similarly on other tabs
202        self.options_widget.setEnablementOnDataLoad()
203
204        # Smearing tab
205        self.smearing_widget.updateSmearing(self.data)
206
207    def acceptsData(self):
208        """ Tells the caller this widget can accept new dataset """
209        return not self.data_is_loaded
210
211    def disableModelCombo(self):
212        """ Disable the combobox """
213        self.cbModel.setEnabled(False)
214        self.lblModel.setEnabled(False)
215
216    def enableModelCombo(self):
217        """ Enable the combobox """
218        self.cbModel.setEnabled(True)
219        self.lblModel.setEnabled(True)
220
221    def disableStructureCombo(self):
222        """ Disable the combobox """
223        self.cbStructureFactor.setEnabled(False)
224        self.lblStructure.setEnabled(False)
225
226    def enableStructureCombo(self):
227        """ Enable the combobox """
228        self.cbStructureFactor.setEnabled(True)
229        self.lblStructure.setEnabled(True)
230
231    def togglePoly(self, isChecked):
232        """ Enable/disable the polydispersity tab """
233        self.tabFitting.setTabEnabled(TAB_POLY, isChecked)
234
235    def toggleMagnetism(self, isChecked):
236        """ Enable/disable the magnetism tab """
237        self.tabFitting.setTabEnabled(TAB_MAGNETISM, isChecked)
238
239    def toggle2D(self, isChecked):
240        """ Enable/disable the controls dependent on 1D/2D data instance """
241        self.chkMagnetism.setEnabled(isChecked)
242        self.is2D = isChecked
243        # Reload the current model
244        if self.kernel_module:
245            self.onSelectModel()
246
247    def initializeControls(self):
248        """
249        Set initial control enablement
250        """
251        self.cmdFit.setEnabled(False)
252        self.cmdPlot.setEnabled(True)
253        self.options_widget.cmdComputePoints.setVisible(False) # probably redundant
254        self.chkPolydispersity.setEnabled(True)
255        self.chkPolydispersity.setCheckState(False)
256        self.chk2DView.setEnabled(True)
257        self.chk2DView.setCheckState(False)
258        self.chkMagnetism.setEnabled(False)
259        self.chkMagnetism.setCheckState(False)
260        # Tabs
261        self.tabFitting.setTabEnabled(TAB_POLY, False)
262        self.tabFitting.setTabEnabled(TAB_MAGNETISM, False)
263        self.lblChi2Value.setText("---")
264        # Smearing tab
265        self.smearing_widget.updateSmearing(self.data)
266        # Line edits in the option tab
267        self.updateQRange()
268
269    def initializeSignals(self):
270        """
271        Connect GUI element signals
272        """
273        # Comboboxes
274        self.cbStructureFactor.currentIndexChanged.connect(self.onSelectStructureFactor)
275        self.cbCategory.currentIndexChanged.connect(self.onSelectCategory)
276        self.cbModel.currentIndexChanged.connect(self.onSelectModel)
277        # Checkboxes
278        self.chk2DView.toggled.connect(self.toggle2D)
279        self.chkPolydispersity.toggled.connect(self.togglePoly)
280        self.chkMagnetism.toggled.connect(self.toggleMagnetism)
281        # Buttons
282        self.cmdFit.clicked.connect(self.onFit)
283        self.cmdPlot.clicked.connect(self.onPlot)
284
285        # Respond to change in parameters from the UI
286        self._model_model.itemChanged.connect(self.updateParamsFromModel)
287        self._poly_model.itemChanged.connect(self.onPolyModelChange)
288        # TODO after the poly_model prototype accepted
289        #self._magnet_model.itemChanged.connect(self.onMagneticModelChange)
290
291        # Signals from separate tabs asking for replot
292        self.options_widget.plot_signal.connect(self.onOptionsUpdate)
293
294    def onSelectModel(self):
295        """
296        Respond to select Model from list event
297        """
298        model = str(self.cbModel.currentText())
299
300        # Reset structure factor
301        self.cbStructureFactor.setCurrentIndex(0)
302
303        # Reset parameters to fit
304        self.parameters_to_fit = None
305        self.has_error_column = False
306
307        # SasModel -> QModel
308        self.SASModelToQModel(model)
309
310        if self.data_is_loaded:
311            self.calculateQGridForModel()
312        else:
313            # Create default datasets if no data passed
314            self.createDefaultDataset()
315
316    def onSelectStructureFactor(self):
317        """
318        Select Structure Factor from list
319        """
320        model = str(self.cbModel.currentText())
321        category = str(self.cbCategory.currentText())
322        structure = str(self.cbStructureFactor.currentText())
323        if category == CATEGORY_STRUCTURE:
324            model = None
325        self.SASModelToQModel(model, structure_factor=structure)
326
327    def onSelectCategory(self):
328        """
329        Select Category from list
330        """
331        category = str(self.cbCategory.currentText())
332        # Check if the user chose "Choose category entry"
333        if category == CATEGORY_DEFAULT:
334            # if the previous category was not the default, keep it.
335            # Otherwise, just return
336            if self._previous_category_index != 0:
337                # We need to block signals, or else state changes on perceived unchanged conditions
338                self.cbCategory.blockSignals(True)
339                self.cbCategory.setCurrentIndex(self._previous_category_index)
340                self.cbCategory.blockSignals(False)
341            return
342
343        if category == CATEGORY_STRUCTURE:
344            self.disableModelCombo()
345            self.enableStructureCombo()
346            self._model_model.clear()
347            return
348
349        # Safely clear and enable the model combo
350        self.cbModel.blockSignals(True)
351        self.cbModel.clear()
352        self.cbModel.blockSignals(False)
353        self.enableModelCombo()
354        self.disableStructureCombo()
355
356        self._previous_category_index = self.cbCategory.currentIndex()
357        # Retrieve the list of models
358        model_list = self.master_category_dict[category]
359        models = []
360        # Populate the models combobox
361        self.cbModel.addItems(sorted([model for (model, _) in model_list]))
362
363    def onPolyModelChange(self, item):
364        """
365        Callback method for updating the main model and sasmodel
366        parameters with the GUI values in the polydispersity view
367        """
368        model_column = item.column()
369        model_row = item.row()
370        name_index = self._poly_model.index(model_row, 0)
371        # Extract changed value. Assumes proper validation by QValidator/Delegate
372        # Checkbox in column 0
373        if model_column == 0:
374            value = item.checkState()
375        else:
376            try:
377                value = float(item.text())
378            except ValueError:
379                # Can't be converted properly, bring back the old value and exit
380                return
381
382        parameter_name = str(self._poly_model.data(name_index).toPyObject()) # "distribution of sld" etc.
383        if "Distribution of" in parameter_name:
384            parameter_name = parameter_name[16:]
385        property_name = str(self._poly_model.headerData(model_column, 1).toPyObject()) # Value, min, max, etc.
386        # print "%s(%s) => %d" % (parameter_name, property_name, value)
387
388        # Update the sasmodel
389        #self.kernel_module.params[parameter_name] = value
390
391        # Reload the main model - may not be required if no variable is shown in main view
392        #model = str(self.cbModel.currentText())
393        #self.SASModelToQModel(model)
394
395        pass # debug anchor
396
397    def onFit(self):
398        """
399        Perform fitting on the current data
400        """
401        fitter = Fit()
402
403        # Data going in
404        data = self.logic.data
405        model = self.kernel_module
406        qmin = self.q_range_min
407        qmax = self.q_range_max
408        params_to_fit = self.parameters_to_fit
409
410        # Potential weights added directly to data
411        self.addWeightingToData(data)
412
413        # Potential smearing added
414        # Remember that smearing_min/max can be None ->
415        # deal with it until Python gets discriminated unions
416        smearing, accuracy, smearing_min, smearing_max = self.smearing_widget.state()
417
418        # These should be updating somehow?
419        fit_id = 0
420        constraints = []
421        smearer = None
422        page_id = [210]
423        handler = None
424        batch_inputs = {}
425        batch_outputs = {}
426        list_page_id = [page_id]
427        #---------------------------------
428
429        # Parameterize the fitter
430        fitter.set_model(model, fit_id, params_to_fit, data=data,
431                         constraints=constraints)
432        fitter.set_data(data=data, id=fit_id, smearer=smearer, qmin=qmin,
433                        qmax=qmax)
434        fitter.select_problem_for_fit(id=fit_id, value=1)
435
436        fitter.fitter_id = page_id
437
438        # Create the fitting thread, based on the fitter
439        calc_fit = FitThread(handler=handler,
440                             fn=[fitter],
441                             batch_inputs=batch_inputs,
442                             batch_outputs=batch_outputs,
443                             page_id=list_page_id,
444                             updatefn=self.updateFit,
445                             completefn=None)
446
447        # start the trhrhread
448        calc_thread = threads.deferToThread(calc_fit.compute)
449        calc_thread.addCallback(self.fitComplete)
450
451        #disable the Fit button
452        self.cmdFit.setText('Calculating...')
453        self.communicate.statusBarUpdateSignal.emit('Fitting started...')
454        self.cmdFit.setEnabled(False)
455
456    def updateFit(self):
457        """
458        """
459        print "UPDATE FIT"
460        pass
461
462    def fitComplete(self, result):
463        """
464        Receive and display fitting results
465        "result" is a tuple of actual result list and the fit time in seconds
466        """
467        #re-enable the Fit button
468        self.cmdFit.setText("Fit")
469        self.cmdFit.setEnabled(True)
470
471        assert result is not None
472
473        res_list = result[0]
474        res = res_list[0]
475        if res.fitness is None or \
476            not np.isfinite(res.fitness) or \
477            np.any(res.pvec == None) or \
478            not np.all(np.isfinite(res.pvec)):
479            msg = "Fitting did not converge!!!"
480            self.communicate.statusBarUpdateSignal.emit(msg)
481            logging.error(msg)
482            return
483
484        elapsed = result[1]
485        msg = "Fitting completed successfully in: %s s.\n" % GuiUtils.formatNumber(elapsed)
486        self.communicate.statusBarUpdateSignal.emit(msg)
487
488        fitness = res.fitness
489        param_list = res.param_list
490        param_values = res.pvec
491        param_stderr = res.stderr
492        params_and_errors = zip(param_values, param_stderr)
493        param_dict = dict(izip(param_list, params_and_errors))
494
495        # Dictionary of fitted parameter: value, error
496        # e.g. param_dic = {"sld":(1.703, 0.0034), "length":(33.455, -0.0983)}
497        self.updateModelFromList(param_dict)
498
499        # update charts
500        self.onPlot()
501
502        # Read only value - we can get away by just printing it here
503        chi2_repr = GuiUtils.formatNumber(fitness, high=True)
504        self.lblChi2Value.setText(chi2_repr)
505
506    def iterateOverModel(self, func):
507        """
508        Take func and throw it inside the model row loop
509        """
510        #assert isinstance(func, function)
511        for row_i in xrange(self._model_model.rowCount()):
512            func(row_i)
513
514    def updateModelFromList(self, param_dict):
515        """
516        Update the model with new parameters, create the errors column
517        """
518        assert isinstance(param_dict, dict)
519        if not dict:
520            return
521
522        def updateFittedValues(row_i):
523            # Utility function for main model update
524            # internal so can use closure for param_dict
525            param_name = str(self._model_model.item(row_i, 0).text())
526            if param_name not in param_dict.keys():
527                return
528            # modify the param value
529            param_repr = GuiUtils.formatNumber(param_dict[param_name][0], high=True)
530            self._model_model.item(row_i, 1).setText(param_repr)
531            if self.has_error_column:
532                error_repr = GuiUtils.formatNumber(param_dict[param_name][1], high=True)
533                self._model_model.item(row_i, 2).setText(error_repr)
534
535        def createErrorColumn(row_i):
536            # Utility function for error column update
537            item = QtGui.QStandardItem()
538            for param_name in param_dict.keys():
539                if str(self._model_model.item(row_i, 0).text()) != param_name:
540                    continue
541                error_repr = GuiUtils.formatNumber(param_dict[param_name][1], high=True)
542                item.setText(error_repr)
543            error_column.append(item)
544
545        self.iterateOverModel(updateFittedValues)
546
547        if self.has_error_column:
548            return
549
550        error_column = []
551        self.iterateOverModel(createErrorColumn)
552
553        # switch off reponse to model change
554        self._model_model.blockSignals(True)
555        self._model_model.insertColumn(2, error_column)
556        self._model_model.blockSignals(False)
557        FittingUtilities.addErrorHeadersToModel(self._model_model)
558        # Adjust the table cells width.
559        # TODO: find a way to dynamically adjust column width while resized expanding
560        self.lstParams.resizeColumnToContents(0)
561        self.lstParams.resizeColumnToContents(4)
562        self.lstParams.resizeColumnToContents(5)
563        self.lstParams.setSizePolicy(QtGui.QSizePolicy.MinimumExpanding, QtGui.QSizePolicy.Expanding)
564
565        self.has_error_column = True
566
567    def onPlot(self):
568        """
569        Plot the current set of data
570        """
571        if not self.data_is_loaded:
572            self.createDefaultDataset()
573        self.calculateQGridForModel()
574
575    def onOptionsUpdate(self):
576        """
577        Update local option values and replot
578        """
579        self.q_range_min, self.q_range_max, self.npts, self.log_points, self.weighting = \
580            self.options_widget.state()
581        # set Q range labels on the main tab
582        self.lblMinRangeDef.setText(str(self.q_range_min))
583        self.lblMaxRangeDef.setText(str(self.q_range_max))
584        self.onPlot()
585
586    def setDefaultStructureCombo(self):
587        """
588        Fill in the structure factors combo box with defaults
589        """
590        structure_factor_list = self.master_category_dict.pop(CATEGORY_STRUCTURE)
591        factors = [factor[0] for factor in structure_factor_list]
592        factors.insert(0, STRUCTURE_DEFAULT)
593        self.cbStructureFactor.clear()
594        self.cbStructureFactor.addItems(sorted(factors))
595
596    def createDefaultDataset(self):
597        """
598        Generate default Dataset 1D/2D for the given model
599        """
600        # Create default datasets if no data passed
601        if self.is2D:
602            qmax = self.q_range_max/np.sqrt(2)
603            qstep = self.npts
604            self.logic.createDefault2dData(qmax, qstep, self.tab_id)
605            return
606        elif self.log_points:
607            qmin = -10.0 if self.q_range_min < 1.e-10 else np.log10(self.q_range_min)
608            qmax =  10.0 if self.q_range_max > 1.e10 else np.log10(self.q_range_max)
609            interval = np.logspace(start=qmin, stop=qmax, num=self.npts, endpoint=True, base=10.0)
610        else:
611            interval = np.linspace(start=self.q_range_min, stop=self.q_range_max,
612                    num=self.npts, endpoint=True)
613        self.logic.createDefault1dData(interval, self.tab_id)
614
615    def readCategoryInfo(self):
616        """
617        Reads the categories in from file
618        """
619        self.master_category_dict = defaultdict(list)
620        self.by_model_dict = defaultdict(list)
621        self.model_enabled_dict = defaultdict(bool)
622
623        categorization_file = CategoryInstaller.get_user_file()
624        if not os.path.isfile(categorization_file):
625            categorization_file = CategoryInstaller.get_default_file()
626        with open(categorization_file, 'rb') as cat_file:
627            self.master_category_dict = json.load(cat_file)
628            self.regenerateModelDict()
629
630        # Load the model dict
631        models = load_standard_models()
632        for model in models:
633            self.models[model.name] = model
634
635    def regenerateModelDict(self):
636        """
637        Regenerates self.by_model_dict which has each model name as the
638        key and the list of categories belonging to that model
639        along with the enabled mapping
640        """
641        self.by_model_dict = defaultdict(list)
642        for category in self.master_category_dict:
643            for (model, enabled) in self.master_category_dict[category]:
644                self.by_model_dict[model].append(category)
645                self.model_enabled_dict[model] = enabled
646
647    def addBackgroundToModel(self, model):
648        """
649        Adds background parameter with default values to the model
650        """
651        assert isinstance(model, QtGui.QStandardItemModel)
652        checked_list = ['background', '0.001', '-inf', 'inf', '1/cm']
653        FittingUtilities.addCheckedListToModel(model, checked_list)
654
655    def addScaleToModel(self, model):
656        """
657        Adds scale parameter with default values to the model
658        """
659        assert isinstance(model, QtGui.QStandardItemModel)
660        checked_list = ['scale', '1.0', '0.0', 'inf', '']
661        FittingUtilities.addCheckedListToModel(model, checked_list)
662
663    def addWeightingToData(self, data):
664        """
665        Adds weighting contribution to fitting data
666        #"""
667        # Send original data for weighting
668        weight = get_weight(data=data, is2d=self.is2D, flag=self.weighting)
669        update_module = data.err_data if self.is2D else data.dy
670        update_module = weight
671
672    def updateQRange(self):
673        """
674        Updates Q Range display
675        """
676        if self.data_is_loaded:
677            self.q_range_min, self.q_range_max, self.npts = self.logic.computeDataRange()
678        # set Q range labels on the main tab
679        self.lblMinRangeDef.setText(str(self.q_range_min))
680        self.lblMaxRangeDef.setText(str(self.q_range_max))
681        # set Q range labels on the options tab
682        self.options_widget.updateQRange(self.q_range_min, self.q_range_max, self.npts)
683
    def SASModelToQModel(self, model_name, structure_factor=None):
        """
        Setting model parameters into table based on selected category.

        Rebuilds the main parameter model (scale, background, then the
        model's own parameters), optionally followed by the parameters of
        ``structure_factor``. It then refreshes the polydispersity and
        magnetism tables and the Q range display.

        :param model_name: form factor model name; passed to
            generate.load_kernel_module and used as a key into self.models.
        :param structure_factor: structure factor name, or None/"None" for
            no structure factor.
        """
        # TODO - modify for structure factor-only choice
        # NOTE(review): onSelectStructureFactor passes model_name=None for the
        # structure factor category; nothing below handles that - confirm.

        # Create/overwrite model items
        # NOTE(review): clear() also drops an existing error column, but
        # has_error_column is not reset here (onSelectModel resets it,
        # onSelectStructureFactor does not).
        self._model_model.clear()

        # Parameter table of the form factor, read from its kernel module
        kernel_module = generate.load_kernel_module(model_name)
        self.model_parameters = modelinfo.make_parameter_table(getattr(kernel_module, 'parameters', []))

        # Instantiate the current sasmodel
        self.kernel_module = self.models[model_name]()

        # Explicitly add scale and background with default values
        self.addScaleToModel(self._model_model)
        self.addBackgroundToModel(self._model_model)

        # Update the QModel
        new_rows = FittingUtilities.addParametersToModel(self.model_parameters, self.is2D)
        for row in new_rows:
            self._model_model.appendRow(row)
        # Update the counter used for multishell display
        self._last_model_row = self._model_model.rowCount()

        # Headers go back in after clear()
        FittingUtilities.addHeadersToModel(self._model_model)

        # Add structure factor ("None" is the STRUCTURE_DEFAULT combobox entry)
        if structure_factor is not None and structure_factor != "None":
            structure_module = generate.load_kernel_module(structure_factor)
            structure_parameters = modelinfo.make_parameter_table(getattr(structure_module, 'parameters', []))
            new_rows = FittingUtilities.addSimpleParametersToModel(structure_parameters, self.is2D)
            for row in new_rows:
                self._model_model.appendRow(row)
            # Update the counter used for multishell display
            self._last_model_row = self._model_model.rowCount()
        else:
            # NOTE(review): reached only when *no* structure factor is chosen;
            # addStructureFactor is defined outside this view - confirm intent.
            self.addStructureFactor()

        # Multishell models need additional treatment
        self.addExtraShells()

        # Add polydispersity to the model
        self.setPolyModel()
        # Add magnetic parameters to the model
        self.setMagneticModel()

        # Adjust the table cells width
        self.lstParams.resizeColumnToContents(0)
        self.lstParams.setSizePolicy(QtGui.QSizePolicy.MinimumExpanding, QtGui.QSizePolicy.Expanding)

        # Now we claim the model has been loaded
        self.model_is_loaded = True

        # Update Q Ranges
        self.updateQRange()
741
742    def updateParamsFromModel(self, item):
743        """
744        Callback method for updating the sasmodel parameters with the GUI values
745        """
746        model_column = item.column()
747
748        if model_column == 0:
749            self.checkboxSelected(item)
750            return
751
752        model_row = item.row()
753        name_index = self._model_model.index(model_row, 0)
754
755        # Extract changed value. Assumes proper validation by QValidator/Delegate
756        value = float(item.text())
757        parameter_name = str(self._model_model.data(name_index).toPyObject()) # sld, background etc.
758        property_name = str(self._model_model.headerData(1, model_column).toPyObject()) # Value, min, max, etc.
759
760        self.kernel_module.params[parameter_name] = value
761
762        # min/max to be changed in self.kernel_module.details[parameter_name] = ['Ang', 0.0, inf]
763        # magnetic params in self.kernel_module.details['M0:parameter_name'] = value
764        # multishell params in self.kernel_module.details[??] = value
765
766        # Force the chart update when actual parameters changed
767        if model_column == 1:
768            self.onPlot()
769
770    def checkboxSelected(self, item):
771        # Assure we're dealing with checkboxes
772        if not item.isCheckable():
773            return
774        status = item.checkState()
775
776        def isChecked(row):
777            return self._model_model.item(row, 0).checkState() == QtCore.Qt.Checked
778
779        def isCheckable(row):
780            return self._model_model.item(row, 0).isCheckable()
781
782        # If multiple rows selected - toggle all of them, filtering uncheckable
783        rows = [s.row() for s in self.lstParams.selectionModel().selectedRows() if isCheckable(s.row())]
784
785        # Switch off signaling from the model to avoid recursion
786        self._model_model.blockSignals(True)
787        # Convert to proper indices and set requested enablement
788        items = [self._model_model.item(row, 0).setCheckState(status) for row in rows]
789        self._model_model.blockSignals(False)
790
791        # update the list of parameters to fit
792        self.parameters_to_fit = [str(self._model_model.item(row_index, 0).text())
793                                  for row_index in xrange(self._model_model.rowCount())
794                                  if isChecked(row_index)]
795
796    def nameForFittedData(self, name):
797        """
798        Generate name for the current fit
799        """
800        if self.is2D:
801            name += "2d"
802        name = "M%i [%s]" % (self.tab_id, name)
803        return name
804
805    def createNewIndex(self, fitted_data):
806        """
807        Create a model or theory index with passed Data1D/Data2D
808        """
809        if self.data_is_loaded:
810            if not fitted_data.name:
811                name = self.nameForFittedData(self.data.filename)
812                fitted_data.title = name
813                fitted_data.name = name
814                fitted_data.filename = name
815                fitted_data.symbol = "Line"
816            self.updateModelIndex(fitted_data)
817        else:
818            name = self.nameForFittedData(self.kernel_module.name)
819            fitted_data.title = name
820            fitted_data.name = name
821            fitted_data.filename = name
822            fitted_data.symbol = "Line"
823            self.createTheoryIndex(fitted_data)
824
825    def updateModelIndex(self, fitted_data):
826        """
827        Update a QStandardModelIndex containing model data
828        """
829        if fitted_data.name is None:
830            name = self.nameForFittedData(self.logic.data.filename)
831            fitted_data.title = name
832            fitted_data.name = name
833        else:
834            name = fitted_data.name
835        # Make this a line if no other defined
836        if hasattr(fitted_data, 'symbol') and fitted_data.symbol is None:
837            fitted_data.symbol = 'Line'
838        # Notify the GUI manager so it can update the main model in DataExplorer
839        GuiUtils.updateModelItemWithPlot(self._index, QtCore.QVariant(fitted_data), name)
840
841    def createTheoryIndex(self, fitted_data):
842        """
843        Create a QStandardModelIndex containing model data
844        """
845        if fitted_data.name is None:
846            name = self.nameForFittedData(self.kernel_module.name)
847            fitted_data.title = name
848            fitted_data.name = name
849            fitted_data.filename = name
850        else:
851            name = fitted_data.name
852        # Notify the GUI manager so it can create the theory model in DataExplorer
853        new_item = GuiUtils.createModelItemWithPlot(QtCore.QVariant(fitted_data), name=name)
854        self.communicate.updateTheoryFromPerspectiveSignal.emit(new_item)
855
856    def methodCalculateForData(self):
857        '''return the method for data calculation'''
858        return Calc1D if isinstance(self.data, Data1D) else Calc2D
859
860    def methodCompleteForData(self):
861        '''return the method for result parsin on calc complete '''
862        return self.complete1D if isinstance(self.data, Data1D) else self.complete2D
863
864    def calculateQGridForModel(self):
865        """
866        Prepare the fitting data object, based on current ModelModel
867        """
868        if self.kernel_module is None:
869            return
870        # Awful API to a backend method.
871        method = self.methodCalculateForData()(data=self.data,
872                              model=self.kernel_module,
873                              page_id=0,
874                              qmin=self.q_range_min,
875                              qmax=self.q_range_max,
876                              smearer=None,
877                              state=None,
878                              weight=None,
879                              fid=None,
880                              toggle_mode_on=False,
881                              completefn=None,
882                              update_chisqr=True,
883                              exception_handler=self.calcException,
884                              source=None)
885
886        calc_thread = threads.deferToThread(method.compute)
887        calc_thread.addCallback(self.methodCompleteForData())
888
889    def complete1D(self, return_data):
890        """
891        Plot the current 1D data
892        """
893        fitted_plot = self.logic.new1DPlot(return_data, self.tab_id)
894        self.calculateResiduals(fitted_plot)
895
896    def complete2D(self, return_data):
897        """
898        Plot the current 2D data
899        """
900        fitted_data = self.logic.new2DPlot(return_data)
901        self.calculateResiduals(fitted_data)
902
903    def calculateResiduals(self, fitted_data):
904        """
905        Calculate and print Chi2 and display chart of residuals
906        """
907        # Create a new index for holding data
908        fitted_data.symbol = "Line"
909        self.createNewIndex(fitted_data)
910        # Calculate difference between return_data and logic.data
911        chi2 = FittingUtilities.calculateChi2(fitted_data, self.logic.data)
912        # Update the control
913        chi2_repr = "---" if chi2 is None else GuiUtils.formatNumber(chi2, high=True)
914        self.lblChi2Value.setText(chi2_repr)
915
916        # Plot residuals if actual data
917        if self.data_is_loaded:
918            residuals_plot = FittingUtilities.plotResiduals(self.data, fitted_data)
919            residuals_plot.id = "Residual " + residuals_plot.id
920            self.createNewIndex(residuals_plot)
921            self.communicate.plotUpdateSignal.emit([residuals_plot])
922
923        self.communicate.plotUpdateSignal.emit([fitted_data])
924
925    def calcException(self, etype, value, tb):
926        """
927        Something horrible happened in the deferred.
928        """
929        logging.error("".join(traceback.format_exception(etype, value, tb)))
930
931    def setTableProperties(self, table):
932        """
933        Setting table properties
934        """
935        # Table properties
936        table.verticalHeader().setVisible(False)
937        table.setAlternatingRowColors(True)
938        table.setSizePolicy(QtGui.QSizePolicy.MinimumExpanding, QtGui.QSizePolicy.Expanding)
939        table.setSelectionBehavior(QtGui.QAbstractItemView.SelectRows)
940        table.resizeColumnsToContents()
941
942        # Header
943        header = table.horizontalHeader()
944        header.setResizeMode(QtGui.QHeaderView.ResizeToContents)
945
946        header.ResizeMode(QtGui.QHeaderView.Interactive)
947        # Resize column 0 and 6 to content
948        header.setResizeMode(0, QtGui.QHeaderView.ResizeToContents)
949        header.setResizeMode(6, QtGui.QHeaderView.ResizeToContents)
950
951    def setPolyModel(self):
952        """
953        Set polydispersity values
954        """
955        if not self.model_parameters:
956            return
957        self._poly_model.clear()
958        for row, param in enumerate(self.model_parameters.form_volume_parameters):
959            # Counters should not be included
960            if not param.polydisperse:
961                continue
962
963            # Potential multishell params
964            checked_list = ["Distribution of "+param.name, str(param.default),
965                            str(param.limits[0]), str(param.limits[1]),
966                            "35", "3", ""]
967            FittingUtilities.addCheckedListToModel(self._poly_model, checked_list)
968
969            #TODO: Need to find cleaner way to input functions
970            func = QtGui.QComboBox()
971            func.addItems(['rectangle', 'array', 'lognormal', 'gaussian', 'schulz',])
972            func_index = self.lstPoly.model().index(row, 6)
973            self.lstPoly.setIndexWidget(func_index, func)
974
975        FittingUtilities.addPolyHeadersToModel(self._poly_model)
976
977    def setMagneticModel(self):
978        """
979        Set magnetism values on model
980        """
981        if not self.model_parameters:
982            return
983        self._magnet_model.clear()
984        for param in self.model_parameters.call_parameters:
985            if param.type != "magnetic":
986                continue
987            checked_list = [param.name,
988                            str(param.default),
989                            str(param.limits[0]),
990                            str(param.limits[1]),
991                            param.units]
992            FittingUtilities.addCheckedListToModel(self._magnet_model, checked_list)
993
994        FittingUtilities.addHeadersToModel(self._magnet_model)
995
996    def addStructureFactor(self):
997        """
998        Add structure factors to the list of parameters
999        """
1000        if self.kernel_module.is_form_factor:
1001            self.enableStructureCombo()
1002        else:
1003            self.disableStructureCombo()
1004
1005    def addExtraShells(self):
1006        """
1007        Add a combobox for multiple shell display
1008        """
1009        param_name, param_length = FittingUtilities.getMultiplicity(self.model_parameters)
1010
1011        if param_length == 0:
1012            return
1013
1014        # cell 1: variable name
1015        item1 = QtGui.QStandardItem(param_name)
1016
1017        func = QtGui.QComboBox()
1018        # Available range of shells displayed in the combobox
1019        func.addItems([str(i) for i in xrange(param_length+1)])
1020
1021        # Respond to index change
1022        func.currentIndexChanged.connect(self.modifyShellsInList)
1023
1024        # cell 2: combobox
1025        item2 = QtGui.QStandardItem()
1026        self._model_model.appendRow([item1, item2])
1027
1028        # Beautify the row:  span columns 2-4
1029        shell_row = self._model_model.rowCount()
1030        shell_index = self._model_model.index(shell_row-1, 1)
1031
1032        self.lstParams.setIndexWidget(shell_index, func)
1033        self._last_model_row = self._model_model.rowCount()
1034
1035        # Set the index to the state-kept value
1036        func.setCurrentIndex(self.current_shell_displayed
1037                             if self.current_shell_displayed < func.count() else 0)
1038
1039    def modifyShellsInList(self, index):
1040        """
1041        Add/remove additional multishell parameters
1042        """
1043        # Find row location of the combobox
1044        last_row = self._last_model_row
1045        remove_rows = self._model_model.rowCount() - last_row
1046
1047        if remove_rows > 1:
1048            self._model_model.removeRows(last_row, remove_rows)
1049
1050        FittingUtilities.addShellsToModel(self.model_parameters, self._model_model, index)
1051        self.current_shell_displayed = index
1052
Note: See TracBrowser for help on using the repository browser.