Ignore:
File:
1 edited

Legend:

Unmodified
Added
Removed
  • src/sas/sasgui/perspectives/fitting/basepage.py

    ree4b3cb r6c382da  
    1717from wx.lib.scrolledpanel import ScrolledPanel 
    1818 
    19 import sasmodels.sasview_model 
     19from sasmodels.weights import MODELS as POLYDISPERSITY_MODELS 
     20 
    2021from sas.sasgui.guiframe.panel_base import PanelBase 
    21 from sas.sasgui.guiframe.utils import format_number, check_float, IdList 
     22from sas.sasgui.guiframe.utils import format_number, check_float, IdList, check_int 
    2223from sas.sasgui.guiframe.events import PanelOnFocusEvent 
    2324from sas.sasgui.guiframe.events import StatusEvent 
     
    626627        self.disp_help_bt.Bind(wx.EVT_BUTTON, self.on_pd_help_clicked, 
    627628                               id=self.disp_help_bt.GetId()) 
    628         self.disp_help_bt.SetToolTipString("Helps for Polydispersion.") 
     629        self.disp_help_bt.SetToolTipString("Help for polydispersion.") 
    629630 
    630631        self.Bind(wx.EVT_RADIOBUTTON, self._set_dipers_Param, 
     
    932933        if len(self._disp_obj_dict) > 0: 
    933934            for k, v in self._disp_obj_dict.iteritems(): 
    934                 self.state._disp_obj_dict[k] = v 
     935                self.state._disp_obj_dict[k] = v.type 
    935936 
    936937            self.state.values = copy.deepcopy(self.values) 
     
    10091010            if len(self._disp_obj_dict) > 0: 
    10101011                for k, v in self._disp_obj_dict.iteritems(): 
    1011                     self.state._disp_obj_dict[k] = v 
     1012                    self.state._disp_obj_dict[k] = v.type 
    10121013 
    10131014            self.state.values = copy.deepcopy(self.values) 
     
    11231124                                                    state.disp_cb_dict[item]) 
    11241125                        # Create the dispersion objects 
    1125                         from sas.models.dispersion_models import ArrayDispersion 
    1126                         disp_model = ArrayDispersion() 
     1126                        disp_model = POLYDISPERSITY_MODELS['array']() 
    11271127                        if hasattr(state, "values") and \ 
    11281128                                 self.disp_cb_dict[item].GetValue() == True: 
     
    13791379        self.weights = copy.deepcopy(state.weights) 
    13801380 
    1381         for key, disp in state._disp_obj_dict.iteritems(): 
    1382             # From saved file, disp_model can not be sent in model obj. 
    1383             # it will be sent as a string here, then converted to model object. 
    1384             if disp.__class__.__name__ == 'str': 
    1385                 disp_model = None 
    1386                 com_str = "from sasmodels.weights " 
    1387                 com_str += "import %s as disp_func \ndisp_model = disp_func()" 
    1388                 exec com_str % disp 
    1389             else: 
    1390                 disp_model = disp 
     1381        for key, disp_type in state._disp_obj_dict.iteritems(): 
     1382            #disp_model = disp 
     1383            disp_model = POLYDISPERSITY_MODELS[disp_type]() 
    13911384            self._disp_obj_dict[key] = disp_model 
    13921385            param_name = key.split('.')[0] 
     
    22812274                continue 
    22822275 
    2283             name = str(item[1]) 
    2284             if name.endswith(".npts") or name.endswith(".nsigmas"): 
     2276            value_ctrl = item[2] 
     2277            if not value_ctrl.IsEnabled(): 
     2278                # ArrayDispersion disables PD, Min, Max, Npts, Nsigs 
    22852279                continue 
    22862280 
    2287             # Check that min, max and value are floats 
    2288             value_ctrl, min_ctrl, max_ctrl = item[2], item[5], item[6] 
    2289             min_str = min_ctrl.GetValue().strip() 
    2290             max_str = max_ctrl.GetValue().strip() 
     2281            name = item[1] 
    22912282            value_str = value_ctrl.GetValue().strip() 
    2292             validity = check_float(value_ctrl) 
    2293             if min_str != "": 
    2294                 validity = validity and check_float(min_ctrl) 
    2295             if max_str != "": 
    2296                 validity = validity and check_float(max_ctrl) 
    2297             if not validity: 
    2298                 continue 
    2299  
    2300             # Check that min is less than max 
    2301             low = -numpy.inf if min_str == "" else float(min_str) 
    2302             high = numpy.inf if max_str == "" else float(max_str) 
    2303             if high < low: 
    2304                 min_ctrl.SetBackgroundColour("pink") 
    2305                 min_ctrl.Refresh() 
    2306                 max_ctrl.SetBackgroundColour("pink") 
    2307                 max_ctrl.Refresh() 
    2308                 #msg = "Invalid fit range for %s: min must be smaller than max"%name 
    2309                 #wx.PostEvent(self._manager.parent, StatusEvent(status=msg)) 
    2310                 continue 
    2311  
    2312             # Force value between min and max 
    2313             value = float(value_str) 
    2314             if value < low: 
    2315                 value = low 
    2316                 value_ctrl.SetValue(format_number(value)) 
    2317             elif value > high: 
    2318                 value = high 
    2319                 value_ctrl.SetValue(format_number(value)) 
     2283            if name.endswith(".npts"): 
     2284                validity = check_int(value_ctrl) 
     2285                if not validity: 
     2286                    continue 
     2287                value = int(value_str) 
     2288 
     2289            elif name.endswith(".nsigmas"): 
     2290                validity = check_float(value_ctrl) 
     2291                if not validity: 
     2292                    continue 
     2293                value = float(value_str) 
     2294 
     2295            else:  # value or polydispersity 
     2296 
     2297                # Check that min, max and value are floats 
     2298                min_ctrl, max_ctrl = item[5], item[6] 
     2299                min_str = min_ctrl.GetValue().strip() 
     2300                max_str = max_ctrl.GetValue().strip() 
     2301                validity = check_float(value_ctrl) 
     2302                if min_str != "": 
     2303                    validity = validity and check_float(min_ctrl) 
     2304                if max_str != "": 
     2305                    validity = validity and check_float(max_ctrl) 
     2306                if not validity: 
     2307                    continue 
     2308 
     2309                # Check that min is less than max 
     2310                low = -numpy.inf if min_str == "" else float(min_str) 
     2311                high = numpy.inf if max_str == "" else float(max_str) 
     2312                if high < low: 
     2313                    min_ctrl.SetBackgroundColour("pink") 
     2314                    min_ctrl.Refresh() 
     2315                    max_ctrl.SetBackgroundColour("pink") 
     2316                    max_ctrl.Refresh() 
     2317                    #msg = "Invalid fit range for %s: min must be smaller than max"%name 
     2318                    #wx.PostEvent(self._manager.parent, StatusEvent(status=msg)) 
     2319                    continue 
     2320 
     2321                # Force value between min and max 
     2322                value = float(value_str) 
     2323                if value < low: 
     2324                    value = low 
     2325                    value_ctrl.SetValue(format_number(value)) 
     2326                elif value > high: 
     2327                    value = high 
     2328                    value_ctrl.SetValue(format_number(value)) 
     2329 
     2330                if name not in self.model.details.keys(): 
     2331                    self.model.details[name] = ["", None, None] 
     2332                old_low, old_high = self.model.details[name][1:3] 
     2333                if old_low != low or old_high != high: 
     2334                    # The configuration has changed but it won't change the 
     2335                    # computed curve so no need to set is_modified to True 
     2336                    #is_modified = True 
     2337                    self.model.details[name][1:3] = low, high 
    23202338 
    23212339            # Update value in model if it has changed 
     
    23232341                self.model.setParam(name, value) 
    23242342                is_modified = True 
    2325  
    2326             if name not in self.model.details.keys(): 
    2327                 self.model.details[name] = ["", None, None] 
    2328             old_low, old_high = self.model.details[name][1:3] 
    2329             if old_low != low or old_high != high: 
    2330                 # The configuration has changed but it won't change the 
    2331                 # computed curve so no need to set is_modified to True 
    2332                 #is_modified = True 
    2333                 self.model.details[name][1:3] = low, high 
    23342343 
    23352344        return is_modified 
     
    25042513                self._disp_obj_dict[name1] = disp_model 
    25052514                self.model.set_dispersion(param_name, disp_model) 
    2506                 self.state._disp_obj_dict[name1] = disp_model 
     2515                self.state._disp_obj_dict[name1] = disp_model.type 
    25072516 
    25082517                value1 = str(format_number(self.model.getParam(name1), True)) 
     
    25272536                        item[0].Enable() 
    25282537                        item[2].Enable() 
     2538                        item[3].Show(True) 
     2539                        item[4].Show(True) 
    25292540                        item[5].Enable() 
    25302541                        item[6].Enable() 
     
    26192630        self._disp_obj_dict[name] = disp 
    26202631        self.model.set_dispersion(name.split('.')[0], disp) 
    2621         self.state._disp_obj_dict[name] = disp 
     2632        self.state._disp_obj_dict[name] = disp.type 
    26222633        self.values[name] = values 
    26232634        self.weights[name] = weights 
     
    26872698        :param disp_function: dispersion distr. function 
    26882699        """ 
    2689         # List of the poly_model name in the combobox 
    2690         list = ["RectangleDispersion", "ArrayDispersion", 
    2691                 "LogNormalDispersion", "GaussianDispersion", 
    2692                 "SchulzDispersion"] 
    2693  
    26942700        # Find the selection 
    2695         try: 
    2696             selection = list.index(disp_func.__class__.__name__) 
    2697             return selection 
    2698         except: 
    2699             return 3 
     2701        if disp_func is not None: 
     2702            try: 
     2703                return POLYDISPERSITY_MODELS.values().index(disp_func.__class__) 
     2704            except ValueError: 
     2705                pass  # Fall through to default class 
     2706        return POLYDISPERSITY_MODELS.keys().index('gaussian') 
    27002707 
    27012708    def on_reset_clicked(self, event): 
     
    32843291                    pd = content[name][1] 
    32853292                    if name.count('.') > 0: 
     3293                        # If this is parameter.width, then pd may be a floating 
     3294                        # point value or it may be an array distribution. 
     3295                        # Nothing to do for parameter.npts or parameter.nsigmas. 
    32863296                        try: 
    32873297                            float(pd) 
    3288                         except: 
     3298                            if name.endswith('.npts'): 
     3299                                pd = int(pd) 
     3300                        except Exception: 
    32893301                            #continue 
    32903302                            if not pd and pd != '': 
     
    32943306                        # Only array func has pd == '' case. 
    32953307                        item[2].Enable(False) 
     3308                    else: 
     3309                        item[2].Enable(True) 
    32963310                    if item[2].__class__.__name__ == "ComboBox": 
    32973311                        if content[name][1] in self.model.fun_list: 
     
    33203334                        pd = value[0] 
    33213335                        if name.count('.') > 0: 
     3336                            # If this is parameter.width, then pd may be a floating 
     3337                            # point value or it may be an array distribution. 
     3338                            # Nothing to do for parameter.npts or parameter.nsigmas. 
    33223339                            try: 
    33233340                                pd = float(pd) 
     3341                                if name.endswith('.npts'): 
     3342                                    pd = int(pd) 
    33243343                            except: 
    33253344                                #continue 
     
    33303349                            # Only array func has pd == '' case. 
    33313350                            item[2].Enable(False) 
     3351                        else: 
     3352                            item[2].Enable(True) 
    33323353                        if item[2].__class__.__name__ == "ComboBox": 
    33333354                            if value[0] in self.model.fun_list: 
     
    33493370        Helps get paste for poly function 
    33503371 
    3351         :param item: Gui param items 
    3352         :param value: the values for parameter ctrols 
    3353         """ 
    3354         is_array = False 
    3355         if len(value[1]) > 0: 
    3356             # Only for dispersion func.s 
    3357             try: 
    3358                 item[7].SetValue(value[1]) 
    3359                 selection = item[7].GetCurrentSelection() 
    3360                 name = item[7].Name 
    3361                 param_name = name.split('.')[0] 
    3362                 dispersity = item[7].GetClientData(selection) 
    3363                 disp_model = dispersity() 
    3364                 # Only for array disp 
    3365                 try: 
    3366                     pd_vals = numpy.array(value[2]) 
    3367                     pd_weights = numpy.array(value[3]) 
    3368                     if len(pd_vals) > 0 and len(pd_vals) > 0: 
    3369                         if len(pd_vals) == len(pd_weights): 
    3370                             self._set_disp_array_cb(item=item) 
    3371                             self._set_array_disp_model(name=name, 
    3372                                                        disp=disp_model, 
    3373                                                        values=pd_vals, 
    3374                                                        weights=pd_weights) 
    3375                             is_array = True 
    3376                 except Exception: 
    3377                     logging.error(traceback.format_exc()) 
    3378                 if not is_array: 
    3379                     self._disp_obj_dict[name] = disp_model 
    3380                     self.model.set_dispersion(name, 
    3381                                               disp_model) 
    3382                     self.state._disp_obj_dict[name] = \ 
    3383                                               disp_model 
    3384                     self.model.set_dispersion(param_name, disp_model) 
    3385                     self.state.values = self.values 
    3386                     self.state.weights = self.weights 
    3387                     self.model._persistency_dict[param_name] = \ 
    3388                                             [self.state.values, 
    3389                                              self.state.weights] 
    3390  
    3391             except Exception: 
    3392                 logging.error(traceback.format_exc()) 
    3393                 print "Error in BasePage._paste_poly_help: %s" % \ 
    3394                                         sys.exc_info()[1] 
    3395  
    3396     def _set_disp_array_cb(self, item): 
     3372        *item* is the parameter name 
     3373 
     3374        *value* depends on which parameter is being processed, and whether it 
     3375        has array polydispersity. 
     3376 
     3377        For parameters without array polydispersity: 
     3378 
     3379            parameter => ['FLOAT', ''] 
     3380            parameter.width => ['FLOAT', 'DISTRIBUTION', ''] 
     3381            parameter.npts => ['FLOAT', ''] 
     3382            parameter.nsigmas => ['FLOAT', ''] 
     3383 
     3384        For parameters with array polydispersity: 
     3385 
     3386            parameter => ['FLOAT', ''] 
     3387            parameter.width => ['FILENAME', 'array', [x1, ...], [w1, ...]] 
     3388            parameter.npts => ['FLOAT', ''] 
     3389            parameter.nsigmas => ['FLOAT', ''] 
     3390        """ 
     3391        # Do nothing if not setting polydispersity 
     3392        if len(value[1]) == 0: 
     3393            return 
     3394 
     3395        try: 
     3396            name = item[7].Name 
     3397            param_name = name.split('.')[0] 
     3398            item[7].SetValue(value[1]) 
     3399            selection = item[7].GetCurrentSelection() 
     3400            dispersity = item[7].GetClientData(selection) 
     3401            disp_model = dispersity() 
     3402 
     3403            if value[1] == 'array': 
     3404                pd_vals = numpy.array(value[2]) 
     3405                pd_weights = numpy.array(value[3]) 
     3406                if len(pd_vals) == 0 or len(pd_vals) != len(pd_weights): 
     3407                    msg = ("bad array distribution parameters for %s" 
     3408                           % param_name) 
     3409                    raise ValueError(msg) 
     3410                self._set_disp_cb(True, item=item) 
     3411                self._set_array_disp_model(name=name, 
     3412                                           disp=disp_model, 
     3413                                           values=pd_vals, 
     3414                                           weights=pd_weights) 
     3415            else: 
     3416                self._set_disp_cb(False, item=item) 
     3417                self._disp_obj_dict[name] = disp_model 
     3418                self.model.set_dispersion(param_name, disp_model) 
     3419                self.state._disp_obj_dict[name] = disp_model.type 
     3420                # TODO: It's not an array, why update values and weights? 
     3421                self.model._persistency_dict[param_name] = \ 
     3422                    [self.values, self.weights] 
     3423                self.state.values = self.values 
     3424                self.state.weights = self.weights 
     3425 
     3426        except Exception: 
     3427            logging.error(traceback.format_exc()) 
     3428            print "Error in BasePage._paste_poly_help: %s" % \ 
     3429                                    sys.exc_info()[1] 
     3430 
     3431    def _set_disp_cb(self, isarray, item): 
    33973432        """ 
    33983433        Set cb for array disp 
    33993434        """ 
    3400         item[0].SetValue(False) 
    3401         item[0].Enable(False) 
    3402         item[2].Enable(False) 
    3403         item[3].Show(False) 
    3404         item[4].Show(False) 
    3405         item[5].SetValue('') 
    3406         item[5].Enable(False) 
    3407         item[6].SetValue('') 
    3408         item[6].Enable(False) 
     3435        if isarray: 
     3436            item[0].SetValue(False) 
     3437            item[0].Enable(False) 
     3438            item[2].Enable(False) 
     3439            item[3].Show(False) 
     3440            item[4].Show(False) 
     3441            item[5].SetValue('') 
     3442            item[5].Enable(False) 
     3443            item[6].SetValue('') 
     3444            item[6].Enable(False) 
     3445        else: 
     3446            item[0].Enable() 
     3447            item[2].Enable() 
     3448            item[3].Show(True) 
     3449            item[4].Show(True) 
     3450            item[5].Enable() 
     3451            item[6].Enable() 
    34093452 
    34103453    def update_pinhole_smear(self): 
Note: See TracChangeset for help on using the changeset viewer.