Changeset 277257f in sasview for src/sas/sascalc/fit/models.py


Ignore:
Timestamp:
Jul 5, 2017 5:28:55 PM (7 years ago)
Author:
Paul Kienzle <pkienzle@…>
Branches:
master, ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, magnetic_scatt, release-4.2.2, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Children:
1386b2f
Parents:
251ef684
Message:

clean up plugin-model handling code; preserve active parameter values when plugin is updated

File:
1 edited

Legend:

Unmodified
Added
Removed
  • src/sas/sascalc/fit/models.py

    r65f3930 r277257f  
    1616 
    1717from sasmodels.sasview_model import load_custom_model, load_standard_models 
     18from sasmodels.sasview_model import MultiplicationModel 
    1819 
    1920# Explicitly import from the pluginmodel module so that py2exe 
     
    2122# as the base class of plug-in models. 
    2223from .pluginmodel import Model1DPlugin 
    23  
    24 from sas.sasgui.guiframe.CategoryInstaller import CategoryInstaller 
    2524 
    2625logger = logging.getLogger(__name__) 
     
    9291    The plugin directory is located in the user's home directory. 
    9392    """ 
    94     dir = os.path.join(os.path.expanduser("~"), '.sasview', PLUGIN_DIR) 
    95  
     93    path = os.path.join(os.path.expanduser("~"), '.sasview', PLUGIN_DIR) 
     94 
     95    # TODO: initializing ~/.sasview/plugin_models doesn't belong in sascalc 
    9696    # If the plugin directory doesn't exist, create it 
    97     if not os.path.isdir(dir): 
    98         os.makedirs(dir) 
    99  
    100     # Find paths needed 
    101     # TODO: remove unneeded try/except block 
    102     try: 
    103         # For source 
    104         if os.path.isdir(os.path.dirname(__file__)): 
    105             p_dir = os.path.join(os.path.dirname(__file__), PLUGIN_DIR) 
    106         else: 
    107             raise 
    108     except Exception: 
    109         # Check for data path next to exe/zip file. 
    110         #Look for maximum n_dir up of the current dir to find plugins dir 
    111         n_dir = 12 
    112         p_dir = None 
    113         f_dir = os.path.join(os.path.dirname(__file__)) 
    114         for i in range(n_dir): 
    115             if i > 1: 
    116                 f_dir, _ = os.path.split(f_dir) 
    117             plugin_path = os.path.join(f_dir, PLUGIN_DIR) 
    118             if os.path.isdir(plugin_path): 
    119                 p_dir = plugin_path 
    120                 break 
    121         if not p_dir: 
    122             raise 
    123     # Place example user models as needed 
    124     if os.path.isdir(p_dir): 
    125         for file in os.listdir(p_dir): 
    126             file_path = os.path.join(p_dir, file) 
    127             if os.path.isfile(file_path): 
    128                 if file.split(".")[-1] == 'py' and\ 
    129                     file.split(".")[0] != '__init__': 
    130                     if not os.path.isfile(os.path.join(dir, file)): 
    131                         shutil.copy(file_path, dir) 
    132  
    133     return dir 
    134  
    135  
    136 class ReportProblem: 
     97    if not os.path.isdir(path): 
     98        os.makedirs(path) 
     99    # TODO: should we be checking for new default models every time? 
     100    initialize_plugins_dir(path) 
     101    return path 
     102 
     103 
     104def initialize_plugins_dir(path): 
     105    # TODO: There are no default plugins 
     106    # TODO: Move default plugins beside sample data files 
     107    # TODO: Should not look for defaults above the root of the sasview install 
     108 
     109    # Walk up the tree looking for default plugin_models directory 
     110    base = os.path.abspath(os.path.dirname(__file__)) 
     111    for _ in range(12): 
     112        default_plugins_path = os.path.join(base, PLUGIN_DIR) 
     113        if os.path.isdir(default_plugins_path): 
     114            break 
     115        base, _ = os.path.split(base) 
     116    else: 
     117        logger.error("default plugins directory not found") 
     118        return 
     119 
     120    # Copy files from default plugins to the .sasview directory 
     121    # This may include c files, depending on the example. 
     122    # Note: files are never replaced, even if the default plugins are updated 
     123    for filename in os.listdir(default_plugins_path): 
     124        # skip __init__.py and all pyc files 
     125        if filename == "__init__.py" or filename.endswith('.pyc'): 
     126            continue 
     127        source = os.path.join(default_plugins_path, filename) 
     128        target = os.path.join(path, filename) 
     129        if os.path.isfile(source) and not os.path.isfile(target): 
     130            shutil.copy(source, target) 
     131 
     132 
     133class ReportProblem(object): 
    137134    """ 
    138135    Class to check for problems with specific values 
     
    161158 
    162159 
    163 def _find_models(): 
     160def find_plugin_models(): 
    164161    """ 
    165162    Find custom models 
    166163    """ 
    167164    # List of plugin objects 
    168     directory = find_plugins_dir() 
     165    plugins_dir = find_plugins_dir() 
    169166    # Go through files in plug-in directory 
    170     if not os.path.isdir(directory): 
    171         msg = "SasView couldn't locate Model plugin folder %r." % directory 
     167    if not os.path.isdir(plugins_dir): 
     168        msg = "SasView couldn't locate Model plugin folder %r." % plugins_dir 
    172169        logger.warning(msg) 
    173170        return {} 
    174171 
    175     plugin_log("looking for models in: %s" % str(directory)) 
    176     # compile_file(directory)  #always recompile the folder plugin 
    177     logger.info("plugin model dir: %s", str(directory)) 
     172    plugin_log("looking for models in: %s" % plugins_dir) 
     173    # compile_file(plugins_dir)  #always recompile the folder plugin 
     174    logger.info("plugin model dir: %s", plugins_dir) 
    178175 
    179176    plugins = {} 
    180     for filename in os.listdir(directory): 
     177    for filename in os.listdir(plugins_dir): 
    181178        name, ext = os.path.splitext(filename) 
    182179        if ext == '.py' and not name == '__init__': 
    183             path = os.path.abspath(os.path.join(directory, filename)) 
     180            path = os.path.abspath(os.path.join(plugins_dir, filename)) 
    184181            try: 
    185182                model = load_custom_model(path) 
    186                 model.name = PLUGIN_NAME_BASE + model.name 
     183                # TODO: add [plug-in] tag to model name in sasview_model 
     184                if not model.name.startswith(PLUGIN_NAME_BASE): 
     185                    model.name = PLUGIN_NAME_BASE + model.name 
    187186                plugins[model.name] = model 
    188187            except Exception: 
     
    196195 
    197196 
    198 class ModelList(object): 
    199     """ 
    200     Contains dictionary of model and their type 
    201     """ 
     197class ModelManagerBase(object): 
     198    """ 
     199    Base class for the model manager 
     200    """ 
     201    #: mutable dictionary of models, continually updated to reflect the 
     202    #: current set of plugins 
     203    model_dictionary = None  # type: Dict[str, Model] 
     204    #: constant list of standard models 
     205    standard_models = None  # type: Dict[str, Model] 
     206    #: list of plugin models reset each time the plugin directory is queried 
     207    plugin_models = None  # type: Dict[str, Model] 
     208    #: timestamp on the plugin directory at the last plugin update 
     209    last_time_dir_modified = 0  # type: int 
     210 
    202211    def __init__(self): 
    203         """ 
    204         """ 
    205         self.mydict = {} 
    206  
    207     def set_list(self, name, mylist): 
    208         """ 
    209         :param name: the type of the list 
    210         :param mylist: the list to add 
    211  
    212         """ 
    213         if name not in self.mydict.keys(): 
    214             self.reset_list(name, mylist) 
    215  
    216     def reset_list(self, name, mylist): 
    217         """ 
    218         :param name: the type of the list 
    219         :param mylist: the list to add 
    220         """ 
    221         self.mydict[name] = mylist 
    222  
    223     def get_list(self): 
    224         """ 
    225         return all the list stored in a dictionary object 
    226         """ 
    227         return self.mydict 
    228  
    229  
    230 class ModelManagerBase(object): 
    231     """ 
    232     Base class for the model manager 
    233     """ 
    234     ## external dict for models 
    235     model_combobox = ModelList() 
    236     ## Dictionary of form factor models 
    237     form_factor_dict = {} 
    238     ## dictionary of structure factor models 
    239     struct_factor_dict = {} 
    240     ##list of structure factors 
    241     struct_list = [] 
    242     ##list of model allowing multiplication by a structure factor 
    243     multiplication_factor = [] 
    244     ##list of multifunctional shapes (i.e. that have user defined number of levels 
    245     multi_func_list = [] 
    246     ## list of added models -- currently python models found in the plugin dir. 
    247     plugins = [] 
    248     ## Event owner (guiframe) 
    249     event_owner = None 
    250     last_time_dir_modified = 0 
    251  
    252     def __init__(self): 
     212        # the model dictionary is allocated at the start and updated to 
     213        # reflect the current list of models.  Be sure to clear it rather 
     214        # than reassign to it. 
    253215        self.model_dictionary = {} 
    254         self.stored_plugins = {} 
    255         self._getModelList() 
    256  
    257     def findModels(self): 
    258         """ 
    259         find  plugin model in directory of plugin .recompile all file 
    260         in the directory if file were modified 
    261         """ 
    262         temp = {} 
    263         if self.is_changed(): 
    264             return  _find_models() 
    265         logger.info("plugin model : %s", str(temp)) 
    266         return temp 
    267  
    268     def _getModelList(self): 
    269         """ 
    270         List of models we want to make available by default 
    271         for this application 
    272  
    273         :return: the next free event ID following the new menu events 
    274  
    275         """ 
    276  
    277         # regular model names only 
    278         self.model_name_list = [] 
    279216 
    280217        #Build list automagically from sasmodels package 
    281         for model in load_standard_models(): 
    282             self.model_dictionary[model.name] = model 
    283             if model.is_structure_factor: 
    284                 self.struct_list.append(model) 
    285             if model.is_form_factor: 
    286                 self.multiplication_factor.append(model) 
    287             if model.is_multiplicity_model: 
    288                 self.multi_func_list.append(model) 
    289             else: 
    290                 self.model_name_list.append(model.name) 
    291  
    292         #Looking for plugins 
    293         self.stored_plugins = self.findModels() 
    294         self.plugins = self.stored_plugins.values() 
    295         for name, plug in self.stored_plugins.iteritems(): 
    296             self.model_dictionary[name] = plug 
    297  
    298         self._get_multifunc_models() 
    299  
    300         return 0 
    301  
    302     def is_changed(self): 
     218        self.standard_models = {model.name: model 
     219                                for model in load_standard_models()} 
     220        # Look for plugins 
     221        self.plugins_reset() 
     222 
     223    def _is_plugin_dir_changed(self): 
    303224        """ 
    304225        check the last time the plugin dir has changed and return true 
     
    308229        plugin_dir = find_plugins_dir() 
    309230        if os.path.isdir(plugin_dir): 
    310             temp = os.path.getmtime(plugin_dir) 
    311             if  self.last_time_dir_modified != temp: 
     231            mod_time = os.path.getmtime(plugin_dir) 
     232            if  self.last_time_dir_modified != mod_time: 
    312233                is_modified = True 
    313                 self.last_time_dir_modified = temp 
     234                self.last_time_dir_modified = mod_time 
    314235 
    315236        return is_modified 
    316237 
    317     def update(self): 
     238    def composable_models(self): 
     239        """ 
     240        return list of standard models that can be used in sum/multiply 
     241        """ 
     242        # TODO: should scan plugin models in addition to standard models 
     243        # and update model_editor so that it doesn't add plugins to the list 
     244        return [model.name for model in self.standard_models.values() 
     245                if not model.is_multiplicity_model] 
     246 
     247    def plugins_update(self): 
    318248        """ 
    319249        return a dictionary of model if 
    320250        new models were added else return empty dictionary 
    321251        """ 
    322         new_plugins = self.findModels() 
    323         if len(new_plugins) > 0: 
    324             for name, plug in  new_plugins.iteritems(): 
    325                 if name not in self.stored_plugins.keys(): 
    326                     self.stored_plugins[name] = plug 
    327                     self.plugins.append(plug) 
    328                     self.model_dictionary[name] = plug 
    329             self.model_combobox.set_list("Plugin Models", self.plugins) 
    330             return self.model_combobox.get_list() 
    331         else: 
    332             return {} 
     252        return self.plugins_reset() 
     253        #if self._is_plugin_dir_changed(): 
     254        #    return self.plugins_reset() 
     255        #else: 
     256        #    return {} 
    333257 
    334258    def plugins_reset(self): 
     
    336260        return a dictionary of model 
    337261        """ 
    338         self.plugins = [] 
    339         new_plugins = _find_models() 
    340         for name, plug in  new_plugins.iteritems(): 
    341             for stored_name, stored_plug in self.stored_plugins.iteritems(): 
    342                 if name == stored_name: 
    343                     del self.stored_plugins[name] 
    344                     del self.model_dictionary[name] 
    345                     break 
    346             self.stored_plugins[name] = plug 
    347             self.plugins.append(plug) 
    348             self.model_dictionary[name] = plug 
    349  
    350         self.model_combobox.reset_list("Plugin Models", self.plugins) 
    351         return self.model_combobox.get_list() 
    352  
    353     def _on_model(self, evt): 
    354         """ 
    355         React to a model menu event 
    356  
    357         :param event: wx menu event 
    358  
    359         """ 
    360         if int(evt.GetId()) in self.form_factor_dict.keys(): 
    361             from sasmodels.sasview_model import MultiplicationModel 
    362             self.model_dictionary[MultiplicationModel.__name__] = MultiplicationModel 
    363             model1, model2 = self.form_factor_dict[int(evt.GetId())] 
    364             model = MultiplicationModel(model1, model2) 
    365         else: 
    366             model = self.struct_factor_dict[str(evt.GetId())]() 
    367  
    368  
    369     def _get_multifunc_models(self): 
    370         """ 
    371         Get the multifunctional models 
    372         """ 
    373         items = [item for item in self.plugins if item.is_multiplicity_model] 
    374         self.multi_func_list = items 
     262        self.plugin_models = find_plugin_models() 
     263        self.model_dictionary.clear() 
     264        self.model_dictionary.update(self.standard_models) 
     265        self.model_dictionary.update(self.plugin_models) 
     266        return self.get_model_list() 
    375267 
    376268    def get_model_list(self): 
    377269        """ 
    378         return dictionary of models for fitpanel use 
    379  
     270        return dictionary of classified models 
     271 
     272        *Structure Factors* are the structure factor models 
     273        *Multi-Functions* are the multiplicity models 
     274        *Plugin Models* are the plugin models 
     275 
     276        Note that a model can be both a plugin and a structure factor or 
     277        multiplicity model. 
    380278        """ 
    381279        ## Model_list now only contains attribute lists not category list. 
     
    387285        ## -PDB   April 26, 2014 
    388286 
    389 #        self.model_combobox.set_list("Shapes", self.shape_list) 
    390 #        self.model_combobox.set_list("Shape-Independent", 
    391 #                                     self.shape_indep_list) 
    392         self.model_combobox.set_list("Structure Factors", self.struct_list) 
    393         self.model_combobox.set_list("Plugin Models", self.plugins) 
    394         self.model_combobox.set_list("P(Q)*S(Q)", self.multiplication_factor) 
    395         self.model_combobox.set_list("multiplication", 
    396                                      self.multiplication_factor) 
    397         self.model_combobox.set_list("Multi-Functions", self.multi_func_list) 
    398         return self.model_combobox.get_list() 
    399  
    400     def get_model_name_list(self): 
    401         """ 
    402         return regular model name list 
    403         """ 
    404         return self.model_name_list 
    405  
    406     def get_model_dictionary(self): 
    407         """ 
    408         return dictionary linking model names to objects 
    409         """ 
    410         return self.model_dictionary 
     287 
     288        # Classify models 
     289        structure_factors = [] 
     290        multiplicity_models = [] 
     291        for model in self.model_dictionary.values(): 
     292            # Old style models don't have is_structure_factor attribute 
     293            if getattr(model, 'is_structure_factor', False): 
     294                structure_factors.append(model) 
     295            if model.is_multiplicity_model: 
     296                multiplicity_models.append(model) 
     297        plugin_models = list(self.plugin_models.values()) 
     298 
     299        return { 
     300            "Structure Factors": structure_factors, 
     301            "Plugin Models": plugin_models, 
     302            "Multi-Functions": multiplicity_models, 
     303        } 
    411304 
    412305 
    413306class ModelManager(object): 
    414307    """ 
    415     implement model 
     308    manage the list of available models 
    416309    """ 
    417310    base = None  # type: ModelManagerBase() 
     
    419312    def __init__(self): 
    420313        if ModelManager.base is None: 
    421             self.base = ModelManagerBase() 
     314            ModelManager.base = ModelManagerBase() 
    422315 
    423316    def cat_model_list(self): 
    424         models = self.base.model_dictionary 
    425         retval = [model for model_name, model in models.items() 
    426                   if model_name not in self.base.stored_plugins] 
    427         return retval 
    428  
    429     def findModels(self): 
    430         return self.base.findModels() 
    431  
    432     def _getModelList(self): 
    433         return self.base._getModelList() 
    434  
    435     def is_changed(self): 
    436         return self.base.is_changed() 
     317        return list(self.base.standard_models.values()) 
    437318 
    438319    def update(self): 
    439         return self.base.update() 
     320        return self.base.plugins_update() 
    440321 
    441322    def plugins_reset(self): 
    442323        return self.base.plugins_reset() 
    443324 
    444     #def populate_menu(self, modelmenu, event_owner): 
    445     #    return self.base.populate_menu(modelmenu, event_owner) 
    446  
    447     def _on_model(self, evt): 
    448         return self.base._on_model(evt) 
    449  
    450     def _get_multifunc_models(self): 
    451         return self.base._get_multifunc_models() 
    452  
    453325    def get_model_list(self): 
    454326        return self.base.get_model_list() 
    455327 
    456     def get_model_name_list(self): 
    457         return self.base.get_model_name_list() 
     328    def composable_models(self): 
     329        return self.base.composable_models() 
    458330 
    459331    def get_model_dictionary(self): 
    460         return self.base.get_model_dictionary() 
     332        return self.base.model_dictionary 
Note: See TracChangeset for help on using the changeset viewer.