Ignore:
Timestamp:
Apr 5, 2017 10:44:27 AM (7 years ago)
Author:
Piotr Rozyczko <rozyczko@…>
Branches:
ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc
Children:
454670d
Parents:
116260a
Message:

Fitting connected. Initial prototype

File:
1 edited

Legend:

Unmodified
Added
Removed
  • src/sas/qtgui/Perspectives/Fitting/FittingWidget.py

    r9f25bce rf182f93  
    1515from sasmodels import modelinfo 
    1616from sasmodels.sasview_model import load_standard_models 
     17from sas.sascalc.fit.BumpsFitting import BumpsFit as Fit 
     18from sas.sasgui.perspectives.fitting.fit_thread import FitThread 
    1719 
    1820from sas.sasgui.guiframe.CategoryInstaller import CategoryInstaller 
     
    6264        # Dictionary of {model name: model class} for the current category 
    6365        self.models = {} 
     66        # Parameters to fit 
     67        self.parameters_to_fit = None 
    6468 
    6569        # Which tab is this widget displayed in? 
     
    6872        # Which shell is being currently displayed? 
    6973        self.current_shell_displayed = 0 
     74        self.has_error_column = False 
    7075 
    7176        # Range parameters 
     
    256261        self.cbStructureFactor.setCurrentIndex(0) 
    257262 
     263        # Reset parameters to fit 
     264        self.parameters_to_fit = None 
     265 
    258266        # SasModel -> QModel 
    259267        self.SASModelToQModel(model) 
     
    350358        Perform fitting on the current data 
    351359        """ 
    352         # TODO: everything here 
    353         #self.calculate1DForModel() 
    354         #calc_fit = FitThread(handler=handler, 
    355         #                     fn=fitter_list, 
    356         #                     batch_inputs=batch_inputs, 
    357         #                     batch_outputs=batch_outputs, 
    358         #                     page_id=list_page_id, 
    359         #                     updatefn=handler.update_fit, 
    360         #                     completefn=self._fit_completed) 
    361  
     360        fitter = Fit() 
     361 
     362        # Data going in 
     363        data = self.logic.data 
     364        model = self.kernel_module 
     365        qmin = self.q_range_min 
     366        qmax = self.q_range_max 
     367        params_to_fit = self.parameters_to_fit 
     368 
     369        # These should be updating somehow? 
     370        fit_id = 0 
     371        constraints = [] 
     372        smearer = None 
     373        page_id = [210] 
     374        handler = None 
     375        batch_inputs = {} 
     376        batch_outputs = {} 
     377        list_page_id = [page_id] 
     378        #--------------------------------- 
     379 
     380        # Parameterize the fitter 
     381        fitter.set_model(model, fit_id, params_to_fit, data=data, 
     382                         constraints=constraints) 
     383        fitter.set_data(data=data, id=fit_id, smearer=smearer, qmin=qmin, 
     384                        qmax=qmax) 
     385        fitter.select_problem_for_fit(id=fit_id, value=1) 
     386 
     387        fitter.fitter_id = page_id 
     388 
     389        # Create the fitting thread, based on the fitter 
     390        calc_fit = FitThread(handler=handler, 
     391                             fn=[fitter], 
     392                             batch_inputs=batch_inputs, 
     393                             batch_outputs=batch_outputs, 
     394                             page_id=list_page_id, 
     395                             updatefn=self.updateFit, 
     396                             completefn=None) 
     397 
     398        # start the trhrhread 
     399        calc_thread = threads.deferToThread(calc_fit.compute) 
     400        calc_thread.addCallback(self.fitComplete) 
     401 
     402        #disable the Fit button 
     403        self.cmdFit.setText("Calculating...") 
     404        self.cmdFit.setEnabled(False) 
     405 
     406    def updateFit(self): 
     407        """ 
     408        """ 
     409        print "UPDATE FIT" 
    362410        pass 
     411 
     412    def fitComplete(self, result): 
     413        """ 
     414        Receive and display fitting results 
     415 
     416        "result" is a tuple of actual result list and the fit time in seconds 
     417        """ 
     418        #re-enable the Fit button 
     419        self.cmdFit.setText("Fit") 
     420        self.cmdFit.setEnabled(True) 
     421        res_list = result[0] 
     422        res = res_list[0] 
     423        if res.fitness is None or \ 
     424            not numpy.isfinite(res.fitness) or \ 
     425            numpy.any(res.pvec == None) or \ 
     426            not numpy.all(numpy.isfinite(res.pvec)): 
     427            msg = "Fitting did not converge!!!" 
     428            logging.error(msg) 
     429            return 
     430 
     431        elapsed = result[1] 
     432        msg = "Fitting completed successfully in: %s s.\n" % GuiUtils.formatNumber(elapsed) 
     433 
     434        self.communicate.statusBarUpdateSignal.emit(msg) 
     435 
     436        fitness = res.fitness 
     437        param_list = res.param_list 
     438        param_values = res.pvec 
     439        param_stderr = res.stderr 
     440        from itertools import izip 
     441        # TODO: add errors to the dict so they can propagate to the view 
     442        params_and_errors = zip(param_values, param_stderr) 
     443        param_dict = dict(izip(param_list, params_and_errors)) 
     444 
     445        # Dictionary of fitted parameter: value, error 
     446        # e.g. param_dic = {"sld":(1.703, 0.0034), "length":(33.455, -0.0983)} 
     447        self.updateModelFromList(param_dict) 
     448 
     449        # Read only value - we can get away by just printing it here 
     450        chi2_repr = GuiUtils.formatNumber(fitness, high=True) 
     451        self.lblChi2Value.setText(chi2_repr) 
     452 
     453        pass 
     454 
     455    def iterateOverModel(self, func): 
     456        """ 
     457        Take func and throw it inside the model row loop 
     458        """ 
     459        #assert isinstance(func, function) 
     460        for row_i in xrange(self._model_model.rowCount()): 
     461            func(row_i) 
     462 
     463    def updateModelFromList(self, param_dict): 
     464        """ 
     465        Update the model with new parameters, create the errors column 
     466        """ 
     467        assert isinstance(param_dict, dict) 
     468        if not dict: 
     469            return 
     470 
     471        def updateValues(row_i): 
     472            # Utility function for main model update 
     473            param_name = str(self._model_model.item(row_i, 0).text()) 
     474            if param_name not in param_dict.keys(): 
     475                return 
     476            # modify the param value 
     477            self._model_model.item(row_i, 1).setText(str(param_dict[param_name][0])) 
     478            if self.has_error_column: 
     479                self._model_model.item(row_i, 2).setText(str(param_dict[param_name][1])) 
     480 
     481        def createColumn(row_i): 
     482            # Utility function for error column update 
     483            item = QtGui.QStandardItem() 
     484            for param_name in param_dict.keys(): 
     485                if str(self._model_model.item(row_i, 0).text()) != param_name: 
     486                    continue 
     487                error_repr = GuiUtils.formatNumber(param_dict[param_name][1], high=True) 
     488                item.setText(error_repr) 
     489            error_column.append(item) 
     490 
     491        self.iterateOverModel(updateValues) 
     492 
     493        if self.has_error_column: 
     494            return 
     495 
     496        error_column = [] 
     497        self.iterateOverModel(createColumn) 
     498 
     499        self.has_error_column = True 
     500        self._model_model.insertColumn(2, error_column) 
     501        FittingUtilities.addErrorHeadersToModel(self._model_model) 
     502 
    363503 
    364504    def onPlot(self): 
     
    370510        self.calculateQGridForModel() 
    371511 
    372     #def onNpts(self, text): 
    373512    def onNpts(self): 
    374513        """ 
     
    387526            self.onPlot() 
    388527 
    389     #def onMinRange(self, text): 
    390528    def onMinRange(self): 
    391529        """ 
     
    405543            self.onPlot() 
    406544 
    407     #def onMaxRange(self, text): 
    408545    def onMaxRange(self): 
    409546        """ 
     
    537674            structure_parameters = modelinfo.make_parameter_table(getattr(structure_module, 'parameters', [])) 
    538675            FittingUtilities.addSimpleParametersToModel(structure_parameters, self._model_model) 
     676            # Set the error column width to 0 
     677            self.lstParams.setColumnWidth(2, 20) 
    539678            # Update the counter used for multishell display 
    540679            self._last_model_row = self._model_model.rowCount() 
     
    565704        """ 
    566705        model_column = item.column() 
     706 
     707        if model_column == 0: 
     708            self.checkboxSelected(item) 
     709            return 
     710 
    567711        model_row = item.row() 
    568712        name_index = self._model_model.index(model_row, 0) 
    569  
    570         if model_column == 0: 
    571             # Assure we're dealing with checkboxes 
    572             if not item.isCheckable(): 
    573                 return 
    574             status = item.checkState() 
    575             # If multiple rows selected - toggle all of them 
    576             rows = [s.row() for s in self.lstParams.selectionModel().selectedRows()] 
    577  
    578             # Switch off signaling from the model to avoid multiple calls 
    579             self._model_model.blockSignals(True) 
    580             # Convert to proper indices and set requested enablement 
    581             items = [self._model_model.item(row, 0).setCheckState(status) for row in rows] 
    582             self._model_model.blockSignals(False) 
    583             return 
    584713 
    585714        # Extract changed value. Assumes proper validation by QValidator/Delegate 
     
    588717        property_name = str(self._model_model.headerData(1, model_column).toPyObject()) # Value, min, max, etc. 
    589718 
    590         # print "%s(%s) => %d" % (parameter_name, property_name, value) 
    591719        self.kernel_module.params[parameter_name] = value 
    592720 
    593721        # min/max to be changed in self.kernel_module.details[parameter_name] = ['Ang', 0.0, inf] 
    594  
    595722        # magnetic params in self.kernel_module.details['M0:parameter_name'] = value 
    596723        # multishell params in self.kernel_module.details[??] = value 
     
    598725        # Force the chart update 
    599726        self.onPlot() 
     727 
     728    def checkboxSelected(self, item): 
     729        # Assure we're dealing with checkboxes 
     730        if not item.isCheckable(): 
     731            return 
     732        status = item.checkState() 
     733 
     734        def isChecked(row): 
     735            return self._model_model.item(row, 0).checkState() == QtCore.Qt.Checked 
     736 
     737        def isCheckable(row): 
     738            return self._model_model.item(row, 0).isCheckable() 
     739 
     740        # If multiple rows selected - toggle all of them, filtering uncheckable 
     741        rows = [s.row() for s in self.lstParams.selectionModel().selectedRows() if isCheckable(s.row())] 
     742 
     743        # Switch off signaling from the model to avoid recursion 
     744        self._model_model.blockSignals(True) 
     745        # Convert to proper indices and set requested enablement 
     746        items = [self._model_model.item(row, 0).setCheckState(status) for row in rows] 
     747        self._model_model.blockSignals(False) 
     748 
     749        # update the list of parameters to fit 
     750        self.parameters_to_fit = [str(self._model_model.item(row_index, 0).text()) 
     751                                  for row_index in xrange(self._model_model.rowCount()) 
     752                                  if isChecked(row_index)] 
    600753 
    601754    def nameForFittedData(self, name): 
     
    714867        chi2 = FittingUtilities.calculateChi2(fitted_data, self.logic.data) 
    715868        # Update the control 
    716         self.lblChi2Value.setText(GuiUtils.formatNumber(chi2, high=True)) 
     869        chi2_repr = "---" if chi2 is None else GuiUtils.formatNumber(chi2, high=True) 
     870        #self.lblChi2Value.setText(GuiUtils.formatNumber(chi2, high=True)) 
     871        self.lblChi2Value.setText(chi2_repr) 
    717872 
    718873        # Plot residuals if actual data 
Note: See TracChangeset for help on using the changeset viewer.