Changeset 277257f in sasview for src/sas/sascalc/fit/models.py
- Timestamp:
- Jul 5, 2017 5:28:55 PM (7 years ago)
- Branches:
- master, ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, magnetic_scatt, release-4.2.2, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
- Children:
- 1386b2f
- Parents:
- 251ef684
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
src/sas/sascalc/fit/models.py
r65f3930 r277257f 16 16 17 17 from sasmodels.sasview_model import load_custom_model, load_standard_models 18 from sasmodels.sasview_model import MultiplicationModel 18 19 19 20 # Explicitly import from the pluginmodel module so that py2exe … … 21 22 # as the base class of plug-in models. 22 23 from .pluginmodel import Model1DPlugin 23 24 from sas.sasgui.guiframe.CategoryInstaller import CategoryInstaller25 24 26 25 logger = logging.getLogger(__name__) … … 92 91 The plugin directory is located in the user's home directory. 93 92 """ 94 dir = os.path.join(os.path.expanduser("~"), '.sasview', PLUGIN_DIR) 95 93 path = os.path.join(os.path.expanduser("~"), '.sasview', PLUGIN_DIR) 94 95 # TODO: initializing ~/.sasview/plugin_models doesn't belong in sascalc 96 96 # If the plugin directory doesn't exist, create it 97 if not os.path.isdir(dir): 98 os.makedirs(dir) 99 100 # Find paths needed 101 # TODO: remove unneeded try/except block 102 try: 103 # For source 104 if os.path.isdir(os.path.dirname(__file__)): 105 p_dir = os.path.join(os.path.dirname(__file__), PLUGIN_DIR) 106 else: 107 raise 108 except Exception: 109 # Check for data path next to exe/zip file. 110 #Look for maximum n_dir up of the current dir to find plugins dir 111 n_dir = 12 112 p_dir = None 113 f_dir = os.path.join(os.path.dirname(__file__)) 114 for i in range(n_dir): 115 if i > 1: 116 f_dir, _ = os.path.split(f_dir) 117 plugin_path = os.path.join(f_dir, PLUGIN_DIR) 118 if os.path.isdir(plugin_path): 119 p_dir = plugin_path 120 break 121 if not p_dir: 122 raise 123 # Place example user models as needed 124 if os.path.isdir(p_dir): 125 for file in os.listdir(p_dir): 126 file_path = os.path.join(p_dir, file) 127 if os.path.isfile(file_path): 128 if file.split(".")[-1] == 'py' and\ 129 file.split(".")[0] != '__init__': 130 if not os.path.isfile(os.path.join(dir, file)): 131 shutil.copy(file_path, dir) 132 133 return dir 134 135 136 class ReportProblem: 97 if not os.path.isdir(path): 98 os.makedirs(path) 99 # TODO: should we be checking for new default models every time? 100 initialize_plugins_dir(path) 101 return path 102 103 104 def initialize_plugins_dir(path): 105 # TODO: There are no default plugins 106 # TODO: Move default plugins beside sample data files 107 # TODO: Should not look for defaults above the root of the sasview install 108 109 # Walk up the tree looking for default plugin_models directory 110 base = os.path.abspath(os.path.dirname(__file__)) 111 for _ in range(12): 112 default_plugins_path = os.path.join(base, PLUGIN_DIR) 113 if os.path.isdir(default_plugins_path): 114 break 115 base, _ = os.path.split(base) 116 else: 117 logger.error("default plugins directory not found") 118 return 119 120 # Copy files from default plugins to the .sasview directory 121 # This may include c files, depending on the example. 122 # Note: files are never replaced, even if the default plugins are updated 123 for filename in os.listdir(default_plugins_path): 124 # skip __init__.py and all pyc files 125 if filename == "__init__.py" or filename.endswith('.pyc'): 126 continue 127 source = os.path.join(default_plugins_path, filename) 128 target = os.path.join(path, filename) 129 if os.path.isfile(source) and not os.path.isfile(target): 130 shutil.copy(source, target) 131 132 133 class ReportProblem(object): 137 134 """ 138 135 Class to check for problems with specific values … … 161 158 162 159 163 def _find_models():160 def find_plugin_models(): 164 161 """ 165 162 Find custom models 166 163 """ 167 164 # List of plugin objects 168 directory= find_plugins_dir()165 plugins_dir = find_plugins_dir() 169 166 # Go through files in plug-in directory 170 if not os.path.isdir( directory):171 msg = "SasView couldn't locate Model plugin folder %r." % directory167 if not os.path.isdir(plugins_dir): 168 msg = "SasView couldn't locate Model plugin folder %r." % plugins_dir 172 169 logger.warning(msg) 173 170 return {} 174 171 175 plugin_log("looking for models in: %s" % str(directory))176 # compile_file( directory) #always recompile the folder plugin177 logger.info("plugin model dir: %s", str(directory))172 plugin_log("looking for models in: %s" % plugins_dir) 173 # compile_file(plugins_dir) #always recompile the folder plugin 174 logger.info("plugin model dir: %s", plugins_dir) 178 175 179 176 plugins = {} 180 for filename in os.listdir( directory):177 for filename in os.listdir(plugins_dir): 181 178 name, ext = os.path.splitext(filename) 182 179 if ext == '.py' and not name == '__init__': 183 path = os.path.abspath(os.path.join( directory, filename))180 path = os.path.abspath(os.path.join(plugins_dir, filename)) 184 181 try: 185 182 model = load_custom_model(path) 186 model.name = PLUGIN_NAME_BASE + model.name 183 # TODO: add [plug-in] tag to model name in sasview_model 184 if not model.name.startswith(PLUGIN_NAME_BASE): 185 model.name = PLUGIN_NAME_BASE + model.name 187 186 plugins[model.name] = model 188 187 except Exception: … … 196 195 197 196 198 class ModelList(object): 199 """ 200 Contains dictionary of model and their type 201 """ 197 class ModelManagerBase(object): 198 """ 199 Base class for the model manager 200 """ 201 #: mutable dictionary of models, continually updated to reflect the 202 #: current set of plugins 203 model_dictionary = None # type: Dict[str, Model] 204 #: constant list of standard models 205 standard_models = None # type: Dict[str, Model] 206 #: list of plugin models reset each time the plugin directory is queried 207 plugin_models = None # type: Dict[str, Model] 208 #: timestamp on the plugin directory at the last plugin update 209 last_time_dir_modified = 0 # type: int 210 202 211 def __init__(self): 203 """ 204 """ 205 self.mydict = {} 206 207 def set_list(self, name, mylist): 208 """ 209 :param name: the type of the list 210 :param mylist: the list to add 211 212 """ 213 if name not in self.mydict.keys(): 214 self.reset_list(name, mylist) 215 216 def reset_list(self, name, mylist): 217 """ 218 :param name: the type of the list 219 :param mylist: the list to add 220 """ 221 self.mydict[name] = mylist 222 223 def get_list(self): 224 """ 225 return all the list stored in a dictionary object 226 """ 227 return self.mydict 228 229 230 class ModelManagerBase(object): 231 """ 232 Base class for the model manager 233 """ 234 ## external dict for models 235 model_combobox = ModelList() 236 ## Dictionary of form factor models 237 form_factor_dict = {} 238 ## dictionary of structure factor models 239 struct_factor_dict = {} 240 ##list of structure factors 241 struct_list = [] 242 ##list of model allowing multiplication by a structure factor 243 multiplication_factor = [] 244 ##list of multifunctional shapes (i.e. that have user defined number of levels 245 multi_func_list = [] 246 ## list of added models -- currently python models found in the plugin dir. 247 plugins = [] 248 ## Event owner (guiframe) 249 event_owner = None 250 last_time_dir_modified = 0 251 252 def __init__(self): 212 # the model dictionary is allocated at the start and updated to 213 # reflect the current list of models. Be sure to clear it rather 214 # than reassign to it. 253 215 self.model_dictionary = {} 254 self.stored_plugins = {}255 self._getModelList()256 257 def findModels(self):258 """259 find plugin model in directory of plugin .recompile all file260 in the directory if file were modified261 """262 temp = {}263 if self.is_changed():264 return _find_models()265 logger.info("plugin model : %s", str(temp))266 return temp267 268 def _getModelList(self):269 """270 List of models we want to make available by default271 for this application272 273 :return: the next free event ID following the new menu events274 275 """276 277 # regular model names only278 self.model_name_list = []279 216 280 217 #Build list automagically from sasmodels package 281 for model in load_standard_models(): 282 self.model_dictionary[model.name] = model 283 if model.is_structure_factor: 284 self.struct_list.append(model) 285 if model.is_form_factor: 286 self.multiplication_factor.append(model) 287 if model.is_multiplicity_model: 288 self.multi_func_list.append(model) 289 else: 290 self.model_name_list.append(model.name) 291 292 #Looking for plugins 293 self.stored_plugins = self.findModels() 294 self.plugins = self.stored_plugins.values() 295 for name, plug in self.stored_plugins.iteritems(): 296 self.model_dictionary[name] = plug 297 298 self._get_multifunc_models() 299 300 return 0 301 302 def is_changed(self): 218 self.standard_models = {model.name: model 219 for model in load_standard_models()} 220 # Look for plugins 221 self.plugins_reset() 222 223 def _is_plugin_dir_changed(self): 303 224 """ 304 225 check the last time the plugin dir has changed and return true … … 308 229 plugin_dir = find_plugins_dir() 309 230 if os.path.isdir(plugin_dir): 310 temp= os.path.getmtime(plugin_dir)311 if self.last_time_dir_modified != temp:231 mod_time = os.path.getmtime(plugin_dir) 232 if self.last_time_dir_modified != mod_time: 312 233 is_modified = True 313 self.last_time_dir_modified = temp234 self.last_time_dir_modified = mod_time 314 235 315 236 return is_modified 316 237 317 def update(self): 238 def composable_models(self): 239 """ 240 return list of standard models that can be used in sum/multiply 241 """ 242 # TODO: should scan plugin models in addition to standard models 243 # and update model_editor so that it doesn't add plugins to the list 244 return [model.name for model in self.standard_models.values() 245 if not model.is_multiplicity_model] 246 247 def plugins_update(self): 318 248 """ 319 249 return a dictionary of model if 320 250 new models were added else return empty dictionary 321 251 """ 322 new_plugins = self.findModels() 323 if len(new_plugins) > 0: 324 for name, plug in new_plugins.iteritems(): 325 if name not in self.stored_plugins.keys(): 326 self.stored_plugins[name] = plug 327 self.plugins.append(plug) 328 self.model_dictionary[name] = plug 329 self.model_combobox.set_list("Plugin Models", self.plugins) 330 return self.model_combobox.get_list() 331 else: 332 return {} 252 return self.plugins_reset() 253 #if self._is_plugin_dir_changed(): 254 # return self.plugins_reset() 255 #else: 256 # return {} 333 257 334 258 def plugins_reset(self): … … 336 260 return a dictionary of model 337 261 """ 338 self.plugins = [] 339 new_plugins = _find_models() 340 for name, plug in new_plugins.iteritems(): 341 for stored_name, stored_plug in self.stored_plugins.iteritems(): 342 if name == stored_name: 343 del self.stored_plugins[name] 344 del self.model_dictionary[name] 345 break 346 self.stored_plugins[name] = plug 347 self.plugins.append(plug) 348 self.model_dictionary[name] = plug 349 350 self.model_combobox.reset_list("Plugin Models", self.plugins) 351 return self.model_combobox.get_list() 352 353 def _on_model(self, evt): 354 """ 355 React to a model menu event 356 357 :param event: wx menu event 358 359 """ 360 if int(evt.GetId()) in self.form_factor_dict.keys(): 361 from sasmodels.sasview_model import MultiplicationModel 362 self.model_dictionary[MultiplicationModel.__name__] = MultiplicationModel 363 model1, model2 = self.form_factor_dict[int(evt.GetId())] 364 model = MultiplicationModel(model1, model2) 365 else: 366 model = self.struct_factor_dict[str(evt.GetId())]() 367 368 369 def _get_multifunc_models(self): 370 """ 371 Get the multifunctional models 372 """ 373 items = [item for item in self.plugins if item.is_multiplicity_model] 374 self.multi_func_list = items 262 self.plugin_models = find_plugin_models() 263 self.model_dictionary.clear() 264 self.model_dictionary.update(self.standard_models) 265 self.model_dictionary.update(self.plugin_models) 266 return self.get_model_list() 375 267 376 268 def get_model_list(self): 377 269 """ 378 return dictionary of models for fitpanel use 379 270 return dictionary of classified models 271 272 *Structure Factors* are the structure factor models 273 *Multi-Functions* are the multiplicity models 274 *Plugin Models* are the plugin models 275 276 Note that a model can be both a plugin and a structure factor or 277 multiplicity model. 380 278 """ 381 279 ## Model_list now only contains attribute lists not category list. … … 387 285 ## -PDB April 26, 2014 388 286 389 # self.model_combobox.set_list("Shapes", self.shape_list) 390 # self.model_combobox.set_list("Shape-Independent", 391 # self.shape_indep_list) 392 self.model_combobox.set_list("Structure Factors", self.struct_list) 393 self.model_combobox.set_list("Plugin Models", self.plugins) 394 self.model_combobox.set_list("P(Q)*S(Q)", self.multiplication_factor) 395 self.model_combobox.set_list("multiplication", 396 self.multiplication_factor) 397 self.model_combobox.set_list("Multi-Functions", self.multi_func_list) 398 return self.model_combobox.get_list() 399 400 def get_model_name_list(self): 401 """ 402 return regular model name list 403 """ 404 return self.model_name_list 405 406 def get_model_dictionary(self): 407 """ 408 return dictionary linking model names to objects 409 """ 410 return self.model_dictionary 287 288 # Classify models 289 structure_factors = [] 290 multiplicity_models = [] 291 for model in self.model_dictionary.values(): 292 # Old style models don't have is_structure_factor attribute 293 if getattr(model, 'is_structure_factor', False): 294 structure_factors.append(model) 295 if model.is_multiplicity_model: 296 multiplicity_models.append(model) 297 plugin_models = list(self.plugin_models.values()) 298 299 return { 300 "Structure Factors": structure_factors, 301 "Plugin Models": plugin_models, 302 "Multi-Functions": multiplicity_models, 303 } 411 304 412 305 413 306 class ModelManager(object): 414 307 """ 415 implement model308 manage the list of available models 416 309 """ 417 310 base = None # type: ModelManagerBase() … … 419 312 def __init__(self): 420 313 if ModelManager.base is None: 421 self.base = ModelManagerBase()314 ModelManager.base = ModelManagerBase() 422 315 423 316 def cat_model_list(self): 424 models = self.base.model_dictionary 425 retval = [model for model_name, model in models.items() 426 if model_name not in self.base.stored_plugins] 427 return retval 428 429 def findModels(self): 430 return self.base.findModels() 431 432 def _getModelList(self): 433 return self.base._getModelList() 434 435 def is_changed(self): 436 return self.base.is_changed() 317 return list(self.base.standard_models.values()) 437 318 438 319 def update(self): 439 return self.base. update()320 return self.base.plugins_update() 440 321 441 322 def plugins_reset(self): 442 323 return self.base.plugins_reset() 443 324 444 #def populate_menu(self, modelmenu, event_owner):445 # return self.base.populate_menu(modelmenu, event_owner)446 447 def _on_model(self, evt):448 return self.base._on_model(evt)449 450 def _get_multifunc_models(self):451 return self.base._get_multifunc_models()452 453 325 def get_model_list(self): 454 326 return self.base.get_model_list() 455 327 456 def get_model_name_list(self):457 return self.base. get_model_name_list()328 def composable_models(self): 329 return self.base.composable_models() 458 330 459 331 def get_model_dictionary(self): 460 return self.base. get_model_dictionary()332 return self.base.model_dictionary
Note: See TracChangeset
for help on using the changeset viewer.