Ignore:
File:
1 edited

Legend:

Unmodified
Added
Removed
  • src/sas/sasgui/perspectives/fitting/basepage.py

    r6c382da ree4b3cb  
    1717from wx.lib.scrolledpanel import ScrolledPanel 
    1818 
    19 from sasmodels.weights import MODELS as POLYDISPERSITY_MODELS 
    20  
     19import sasmodels.sasview_model 
    2120from sas.sasgui.guiframe.panel_base import PanelBase 
    22 from sas.sasgui.guiframe.utils import format_number, check_float, IdList, check_int 
     21from sas.sasgui.guiframe.utils import format_number, check_float, IdList 
    2322from sas.sasgui.guiframe.events import PanelOnFocusEvent 
    2423from sas.sasgui.guiframe.events import StatusEvent 
     
    627626        self.disp_help_bt.Bind(wx.EVT_BUTTON, self.on_pd_help_clicked, 
    628627                               id=self.disp_help_bt.GetId()) 
    629         self.disp_help_bt.SetToolTipString("Help for polydispersion.") 
     628        self.disp_help_bt.SetToolTipString("Helps for Polydispersion.") 
    630629 
    631630        self.Bind(wx.EVT_RADIOBUTTON, self._set_dipers_Param, 
     
    933932        if len(self._disp_obj_dict) > 0: 
    934933            for k, v in self._disp_obj_dict.iteritems(): 
    935                 self.state._disp_obj_dict[k] = v.type 
     934                self.state._disp_obj_dict[k] = v 
    936935 
    937936            self.state.values = copy.deepcopy(self.values) 
     
    10101009            if len(self._disp_obj_dict) > 0: 
    10111010                for k, v in self._disp_obj_dict.iteritems(): 
    1012                     self.state._disp_obj_dict[k] = v.type 
     1011                    self.state._disp_obj_dict[k] = v 
    10131012 
    10141013            self.state.values = copy.deepcopy(self.values) 
     
    11241123                                                    state.disp_cb_dict[item]) 
    11251124                        # Create the dispersion objects 
    1126                         disp_model = POLYDISPERSITY_MODELS['array']() 
     1125                        from sas.models.dispersion_models import ArrayDispersion 
     1126                        disp_model = ArrayDispersion() 
    11271127                        if hasattr(state, "values") and \ 
    11281128                                 self.disp_cb_dict[item].GetValue() == True: 
     
    13791379        self.weights = copy.deepcopy(state.weights) 
    13801380 
    1381         for key, disp_type in state._disp_obj_dict.iteritems(): 
    1382             #disp_model = disp 
    1383             disp_model = POLYDISPERSITY_MODELS[disp_type]() 
     1381        for key, disp in state._disp_obj_dict.iteritems(): 
     1382            # From saved file, disp_model can not be sent in model obj. 
     1383            # it will be sent as a string here, then converted to model object. 
     1384            if disp.__class__.__name__ == 'str': 
     1385                disp_model = None 
     1386                com_str = "from sasmodels.weights " 
     1387                com_str += "import %s as disp_func \ndisp_model = disp_func()" 
     1388                exec com_str % disp 
     1389            else: 
     1390                disp_model = disp 
    13841391            self._disp_obj_dict[key] = disp_model 
    13851392            param_name = key.split('.')[0] 
     
    22742281                continue 
    22752282 
    2276             value_ctrl = item[2] 
    2277             if not value_ctrl.IsEnabled(): 
    2278                 # ArrayDispersion disables PD, Min, Max, Npts, Nsigs 
     2283            name = str(item[1]) 
     2284            if name.endswith(".npts") or name.endswith(".nsigmas"): 
    22792285                continue 
    22802286 
    2281             name = item[1] 
     2287            # Check that min, max and value are floats 
     2288            value_ctrl, min_ctrl, max_ctrl = item[2], item[5], item[6] 
     2289            min_str = min_ctrl.GetValue().strip() 
     2290            max_str = max_ctrl.GetValue().strip() 
    22822291            value_str = value_ctrl.GetValue().strip() 
    2283             if name.endswith(".npts"): 
    2284                 validity = check_int(value_ctrl) 
    2285                 if not validity: 
    2286                     continue 
    2287                 value = int(value_str) 
    2288  
    2289             elif name.endswith(".nsigmas"): 
    2290                 validity = check_float(value_ctrl) 
    2291                 if not validity: 
    2292                     continue 
    2293                 value = float(value_str) 
    2294  
    2295             else:  # value or polydispersity 
    2296  
    2297                 # Check that min, max and value are floats 
    2298                 min_ctrl, max_ctrl = item[5], item[6] 
    2299                 min_str = min_ctrl.GetValue().strip() 
    2300                 max_str = max_ctrl.GetValue().strip() 
    2301                 validity = check_float(value_ctrl) 
    2302                 if min_str != "": 
    2303                     validity = validity and check_float(min_ctrl) 
    2304                 if max_str != "": 
    2305                     validity = validity and check_float(max_ctrl) 
    2306                 if not validity: 
    2307                     continue 
    2308  
    2309                 # Check that min is less than max 
    2310                 low = -numpy.inf if min_str == "" else float(min_str) 
    2311                 high = numpy.inf if max_str == "" else float(max_str) 
    2312                 if high < low: 
    2313                     min_ctrl.SetBackgroundColour("pink") 
    2314                     min_ctrl.Refresh() 
    2315                     max_ctrl.SetBackgroundColour("pink") 
    2316                     max_ctrl.Refresh() 
    2317                     #msg = "Invalid fit range for %s: min must be smaller than max"%name 
    2318                     #wx.PostEvent(self._manager.parent, StatusEvent(status=msg)) 
    2319                     continue 
    2320  
    2321                 # Force value between min and max 
    2322                 value = float(value_str) 
    2323                 if value < low: 
    2324                     value = low 
    2325                     value_ctrl.SetValue(format_number(value)) 
    2326                 elif value > high: 
    2327                     value = high 
    2328                     value_ctrl.SetValue(format_number(value)) 
    2329  
    2330                 if name not in self.model.details.keys(): 
    2331                     self.model.details[name] = ["", None, None] 
    2332                 old_low, old_high = self.model.details[name][1:3] 
    2333                 if old_low != low or old_high != high: 
    2334                     # The configuration has changed but it won't change the 
    2335                     # computed curve so no need to set is_modified to True 
    2336                     #is_modified = True 
    2337                     self.model.details[name][1:3] = low, high 
     2292            validity = check_float(value_ctrl) 
     2293            if min_str != "": 
     2294                validity = validity and check_float(min_ctrl) 
     2295            if max_str != "": 
     2296                validity = validity and check_float(max_ctrl) 
     2297            if not validity: 
     2298                continue 
     2299 
     2300            # Check that min is less than max 
     2301            low = -numpy.inf if min_str == "" else float(min_str) 
     2302            high = numpy.inf if max_str == "" else float(max_str) 
     2303            if high < low: 
     2304                min_ctrl.SetBackgroundColour("pink") 
     2305                min_ctrl.Refresh() 
     2306                max_ctrl.SetBackgroundColour("pink") 
     2307                max_ctrl.Refresh() 
     2308                #msg = "Invalid fit range for %s: min must be smaller than max"%name 
     2309                #wx.PostEvent(self._manager.parent, StatusEvent(status=msg)) 
     2310                continue 
     2311 
     2312            # Force value between min and max 
     2313            value = float(value_str) 
     2314            if value < low: 
     2315                value = low 
     2316                value_ctrl.SetValue(format_number(value)) 
     2317            elif value > high: 
     2318                value = high 
     2319                value_ctrl.SetValue(format_number(value)) 
    23382320 
    23392321            # Update value in model if it has changed 
     
    23412323                self.model.setParam(name, value) 
    23422324                is_modified = True 
     2325 
     2326            if name not in self.model.details.keys(): 
     2327                self.model.details[name] = ["", None, None] 
     2328            old_low, old_high = self.model.details[name][1:3] 
     2329            if old_low != low or old_high != high: 
     2330                # The configuration has changed but it won't change the 
     2331                # computed curve so no need to set is_modified to True 
     2332                #is_modified = True 
     2333                self.model.details[name][1:3] = low, high 
    23432334 
    23442335        return is_modified 
     
    25132504                self._disp_obj_dict[name1] = disp_model 
    25142505                self.model.set_dispersion(param_name, disp_model) 
    2515                 self.state._disp_obj_dict[name1] = disp_model.type 
     2506                self.state._disp_obj_dict[name1] = disp_model 
    25162507 
    25172508                value1 = str(format_number(self.model.getParam(name1), True)) 
     
    25362527                        item[0].Enable() 
    25372528                        item[2].Enable() 
    2538                         item[3].Show(True) 
    2539                         item[4].Show(True) 
    25402529                        item[5].Enable() 
    25412530                        item[6].Enable() 
     
    26302619        self._disp_obj_dict[name] = disp 
    26312620        self.model.set_dispersion(name.split('.')[0], disp) 
    2632         self.state._disp_obj_dict[name] = disp.type 
     2621        self.state._disp_obj_dict[name] = disp 
    26332622        self.values[name] = values 
    26342623        self.weights[name] = weights 
     
    26982687        :param disp_function: dispersion distr. function 
    26992688        """ 
     2689        # List of the poly_model name in the combobox 
     2690        list = ["RectangleDispersion", "ArrayDispersion", 
     2691                "LogNormalDispersion", "GaussianDispersion", 
     2692                "SchulzDispersion"] 
     2693 
    27002694        # Find the selection 
    2701         if disp_func is not None: 
    2702             try: 
    2703                 return POLYDISPERSITY_MODELS.values().index(disp_func.__class__) 
    2704             except ValueError: 
    2705                 pass  # Fall through to default class 
    2706         return POLYDISPERSITY_MODELS.keys().index('gaussian') 
     2695        try: 
     2696            selection = list.index(disp_func.__class__.__name__) 
     2697            return selection 
     2698        except: 
     2699            return 3 
    27072700 
    27082701    def on_reset_clicked(self, event): 
     
    32913284                    pd = content[name][1] 
    32923285                    if name.count('.') > 0: 
    3293                         # If this is parameter.width, then pd may be a floating 
    3294                         # point value or it may be an array distribution. 
    3295                         # Nothing to do for parameter.npts or parameter.nsigmas. 
    32963286                        try: 
    32973287                            float(pd) 
    3298                             if name.endswith('.npts'): 
    3299                                 pd = int(pd) 
    3300                         except Exception: 
     3288                        except: 
    33013289                            #continue 
    33023290                            if not pd and pd != '': 
     
    33063294                        # Only array func has pd == '' case. 
    33073295                        item[2].Enable(False) 
    3308                     else: 
    3309                         item[2].Enable(True) 
    33103296                    if item[2].__class__.__name__ == "ComboBox": 
    33113297                        if content[name][1] in self.model.fun_list: 
     
    33343320                        pd = value[0] 
    33353321                        if name.count('.') > 0: 
    3336                             # If this is parameter.width, then pd may be a floating 
    3337                             # point value or it may be an array distribution. 
    3338                             # Nothing to do for parameter.npts or parameter.nsigmas. 
    33393322                            try: 
    33403323                                pd = float(pd) 
    3341                                 if name.endswith('.npts'): 
    3342                                     pd = int(pd) 
    33433324                            except: 
    33443325                                #continue 
     
    33493330                            # Only array func has pd == '' case. 
    33503331                            item[2].Enable(False) 
    3351                         else: 
    3352                             item[2].Enable(True) 
    33533332                        if item[2].__class__.__name__ == "ComboBox": 
    33543333                            if value[0] in self.model.fun_list: 
     
    33703349        Helps get paste for poly function 
    33713350 
    3372         *item* is the parameter name 
    3373  
    3374         *value* depends on which parameter is being processed, and whether it 
    3375         has array polydispersity. 
    3376  
    3377         For parameters without array polydispersity: 
    3378  
    3379             parameter => ['FLOAT', ''] 
    3380             parameter.width => ['FLOAT', 'DISTRIBUTION', ''] 
    3381             parameter.npts => ['FLOAT', ''] 
    3382             parameter.nsigmas => ['FLOAT', ''] 
    3383  
    3384         For parameters with array polydispersity: 
    3385  
    3386             parameter => ['FLOAT', ''] 
    3387             parameter.width => ['FILENAME', 'array', [x1, ...], [w1, ...]] 
    3388             parameter.npts => ['FLOAT', ''] 
    3389             parameter.nsigmas => ['FLOAT', ''] 
    3390         """ 
    3391         # Do nothing if not setting polydispersity 
    3392         if len(value[1]) == 0: 
    3393             return 
    3394  
    3395         try: 
    3396             name = item[7].Name 
    3397             param_name = name.split('.')[0] 
    3398             item[7].SetValue(value[1]) 
    3399             selection = item[7].GetCurrentSelection() 
    3400             dispersity = item[7].GetClientData(selection) 
    3401             disp_model = dispersity() 
    3402  
    3403             if value[1] == 'array': 
    3404                 pd_vals = numpy.array(value[2]) 
    3405                 pd_weights = numpy.array(value[3]) 
    3406                 if len(pd_vals) == 0 or len(pd_vals) != len(pd_weights): 
    3407                     msg = ("bad array distribution parameters for %s" 
    3408                            % param_name) 
    3409                     raise ValueError(msg) 
    3410                 self._set_disp_cb(True, item=item) 
    3411                 self._set_array_disp_model(name=name, 
    3412                                            disp=disp_model, 
    3413                                            values=pd_vals, 
    3414                                            weights=pd_weights) 
    3415             else: 
    3416                 self._set_disp_cb(False, item=item) 
    3417                 self._disp_obj_dict[name] = disp_model 
    3418                 self.model.set_dispersion(param_name, disp_model) 
    3419                 self.state._disp_obj_dict[name] = disp_model.type 
    3420                 # TODO: It's not an array, why update values and weights? 
    3421                 self.model._persistency_dict[param_name] = \ 
    3422                     [self.values, self.weights] 
    3423                 self.state.values = self.values 
    3424                 self.state.weights = self.weights 
    3425  
    3426         except Exception: 
    3427             logging.error(traceback.format_exc()) 
    3428             print "Error in BasePage._paste_poly_help: %s" % \ 
    3429                                     sys.exc_info()[1] 
    3430  
    3431     def _set_disp_cb(self, isarray, item): 
     3351        :param item: Gui param items 
     3352        :param value: the values for parameter ctrols 
     3353        """ 
     3354        is_array = False 
     3355        if len(value[1]) > 0: 
     3356            # Only for dispersion func.s 
     3357            try: 
     3358                item[7].SetValue(value[1]) 
     3359                selection = item[7].GetCurrentSelection() 
     3360                name = item[7].Name 
     3361                param_name = name.split('.')[0] 
     3362                dispersity = item[7].GetClientData(selection) 
     3363                disp_model = dispersity() 
     3364                # Only for array disp 
     3365                try: 
     3366                    pd_vals = numpy.array(value[2]) 
     3367                    pd_weights = numpy.array(value[3]) 
     3368                    if len(pd_vals) > 0 and len(pd_vals) > 0: 
     3369                        if len(pd_vals) == len(pd_weights): 
     3370                            self._set_disp_array_cb(item=item) 
     3371                            self._set_array_disp_model(name=name, 
     3372                                                       disp=disp_model, 
     3373                                                       values=pd_vals, 
     3374                                                       weights=pd_weights) 
     3375                            is_array = True 
     3376                except Exception: 
     3377                    logging.error(traceback.format_exc()) 
     3378                if not is_array: 
     3379                    self._disp_obj_dict[name] = disp_model 
     3380                    self.model.set_dispersion(name, 
     3381                                              disp_model) 
     3382                    self.state._disp_obj_dict[name] = \ 
     3383                                              disp_model 
     3384                    self.model.set_dispersion(param_name, disp_model) 
     3385                    self.state.values = self.values 
     3386                    self.state.weights = self.weights 
     3387                    self.model._persistency_dict[param_name] = \ 
     3388                                            [self.state.values, 
     3389                                             self.state.weights] 
     3390 
     3391            except Exception: 
     3392                logging.error(traceback.format_exc()) 
     3393                print "Error in BasePage._paste_poly_help: %s" % \ 
     3394                                        sys.exc_info()[1] 
     3395 
     3396    def _set_disp_array_cb(self, item): 
    34323397        """ 
    34333398        Set cb for array disp 
    34343399        """ 
    3435         if isarray: 
    3436             item[0].SetValue(False) 
    3437             item[0].Enable(False) 
    3438             item[2].Enable(False) 
    3439             item[3].Show(False) 
    3440             item[4].Show(False) 
    3441             item[5].SetValue('') 
    3442             item[5].Enable(False) 
    3443             item[6].SetValue('') 
    3444             item[6].Enable(False) 
    3445         else: 
    3446             item[0].Enable() 
    3447             item[2].Enable() 
    3448             item[3].Show(True) 
    3449             item[4].Show(True) 
    3450             item[5].Enable() 
    3451             item[6].Enable() 
     3400        item[0].SetValue(False) 
     3401        item[0].Enable(False) 
     3402        item[2].Enable(False) 
     3403        item[3].Show(False) 
     3404        item[4].Show(False) 
     3405        item[5].SetValue('') 
     3406        item[5].Enable(False) 
     3407        item[6].SetValue('') 
     3408        item[6].Enable(False) 
    34523409 
    34533410    def update_pinhole_smear(self): 
Note: See TracChangeset for help on using the changeset viewer.