source: sasview/src/sas/sasgui/perspectives/fitting/models.py @ 13374be

ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, magnetic_scatt, release-4.2.2, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Last change on this file since 13374be was 13374be, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

Merge branch 'master' into 4_1_issues

  • Property mode set to 100644
File size: 16.3 KB
RevLine 
[f0d720b]1"""
2    Utilities to manage models
3"""
[a1b8fee]4from __future__ import print_function
5
[7673ecd]6import traceback
[f0d720b]7import os
8import sys
9import os.path
10# Time is needed by the log method
11import time
[7673ecd]12import datetime
[f0d720b]13import logging
14import py_compile
15import shutil
[33dc18f]16from copy import copy
[f0d720b]17# Explicitly import from the pluginmodel module so that py2exe
18# places it in the distribution. The Model1DPlugin class is used
19# as the base class of plug-in models.
[9e531f2]20from sas.sascalc.fit.pluginmodel import Model1DPlugin
[d85c194]21from sas.sasgui.guiframe.CategoryInstaller import CategoryInstaller
[dcdca68]22from sasmodels.sasview_model import load_custom_model, load_standard_models
[05b8e01]23from sas.sasgui.perspectives.fitting.fitpage import CUSTOM_MODEL
[f66d9d1]24
# Module logger; messages go to the "<module name>" logger hierarchy.
logger = logging.getLogger(__name__)

# Directory name used for plug-in models, both under ~/.sasview and next
# to this module (see find_plugins_dir).
PLUGIN_DIR = 'plugin_models'

# File that collects plug-in loading diagnostics (written by plugin_log).
PLUGIN_LOG = os.path.join(
    os.path.expanduser("~"), '.sasview', PLUGIN_DIR, "plugins.log"
)

# Prefix added to the display name of every plug-in model.
PLUGIN_NAME_BASE = '[plug-in] '
[f0d720b]32
33def get_model_python_path():
34    """
[dcdca68]35    Returns the python path for a model
[f0d720b]36    """
37    return os.path.dirname(__file__)
38
39
def plugin_log(message):
    """
    Append a timestamped message to the plug-in log file (PLUGIN_LOG).

    Each entry is written as ``YYYY-mm-dd HH:MM:SS: <message>`` followed by
    a newline.

    :param message: text to record
    """
    stamp = datetime.datetime.fromtimestamp(time.time()).strftime(
        '%Y-%m-%d %H:%M:%S')
    # 'with' guarantees the file handle is closed even if the write fails.
    with open(PLUGIN_LOG, 'a') as out:
        out.write("%s: %s\n" % (stamp, message))
49
50
def _check_plugin(model, name):
    """
    Do some checking before adding a model to the plug-in list.

    The class must subclass Model1DPlugin and be named ``Model``. It must be
    constructible without arguments. Its instances must provide a
    ``function`` method that can be called without arguments. Every failure
    is written to the plug-in log.

    :param model: class model to add into the plugin list
    :param name: name of the module plugin

    :return model: model if valid model or None if not valid
    """
    # Check if the plugin is of type Model1DPlugin
    if not issubclass(model, Model1DPlugin):
        msg = "Plugin %s must be of type Model1DPlugin \n" % str(name)
        plugin_log(msg)
        return None
    if model.__name__ != "Model":
        msg = "Plugin %s class name must be Model \n" % str(name)
        plugin_log(msg)
        return None

    try:
        new_instance = model()
    except Exception:
        # sys.exc_info() replaces the deprecated (and Python-3-less)
        # sys.exc_type; catching Exception keeps Ctrl-C working.
        exc_type, exc_value = sys.exc_info()[:2]
        msg = "Plugin %s error in __init__ \n\t: %s %s\n" % (str(name),
                                                             str(exc_type),
                                                             exc_value)
        plugin_log(msg)
        return None

    if not hasattr(new_instance, "function"):
        msg = "Plugin  %s needs a method called function \n" % str(name)
        plugin_log(msg)
        return None

    try:
        # Only checks that the call succeeds; the result is not used.
        new_instance.function()
    except Exception:
        exc_type, exc_value = sys.exc_info()[:2]
        msg = "Plugin %s: error writing function \n\t :%s %s\n " % \
                (str(name), str(exc_type), exc_value)
        plugin_log(msg)
        return None
    return model
92
93
def _find_example_plugins_dir():
    """
    Locate the directory of example plug-in models that ships with SasView.

    For a source install, it is ``<module dir>/PLUGIN_DIR``. If the module
    directory is not a real directory (e.g. the module is inside an
    exe/zip), the module directory and up to ten of its parents are searched
    for a PLUGIN_DIR directory.

    :return: path of the example plug-in directory (for a source install
             the returned path is not guaranteed to exist)
    :raises RuntimeError: if no directory is found in the fallback search
    """
    module_dir = os.path.dirname(__file__)
    if os.path.isdir(module_dir):
        # For source
        return os.path.join(module_dir, PLUGIN_DIR)

    # Check for data path next to exe/zip file: this directory and up to
    # ten levels above it.
    search_dir = module_dir
    for _ in range(11):
        candidate = os.path.join(search_dir, PLUGIN_DIR)
        if os.path.isdir(candidate):
            return candidate
        search_dir = os.path.dirname(search_dir)
    raise RuntimeError("Could not locate the %r directory near %r"
                       % (PLUGIN_DIR, module_dir))


def find_plugins_dir():
    """
    Find path of the plugins directory.

    The plugin directory is ``~/.sasview/PLUGIN_DIR`` in the user's home
    directory and is created if missing. Example ``.py`` models (excluding
    ``__init__``) found in the shipped example directory are copied into it,
    unless a file of the same name already exists there.

    :return: path of the user's plug-in directory
    :raises RuntimeError: if the example directory cannot be located
    """
    user_dir = os.path.join(os.path.expanduser("~"), '.sasview', PLUGIN_DIR)

    # If the plugin directory doesn't exist, create it
    if not os.path.isdir(user_dir):
        os.makedirs(user_dir)

    # Place example user models as needed
    example_dir = _find_example_plugins_dir()
    if os.path.isdir(example_dir):
        for filename in os.listdir(example_dir):
            source = os.path.join(example_dir, filename)
            if not os.path.isfile(source):
                continue
            parts = filename.split(".")
            if parts[-1] == 'py' and parts[0] != '__init__':
                # Never overwrite a model the user may have edited.
                if not os.path.isfile(os.path.join(user_dir, filename)):
                    shutil.copy(source, user_dir)

    return user_dir
138
139
class ReportProblem:
    """
    Truth-value hook passed as ``quiet`` to compileall (see compile_file).

    Suppose it is evaluated for truthiness while a py_compile.PyCompileError
    is being handled. It then prints the problem and re-raises that error,
    so the compile failure propagates to the caller. In every other case it
    is simply true.
    """
    def __nonzero__(self):
        exc_type, exc_value = sys.exc_info()[:2]
        if exc_type is not None and issubclass(exc_type,
                                               py_compile.PyCompileError):
            print("Problem with", repr(exc_value))
            # A bare raise re-raises the exception currently being handled
            # with its original traceback; unlike the old
            # ``raise type, value, tb`` it is valid on Python 2 and 3.
            raise
        return True

    # Python 3 consults __bool__ instead of __nonzero__.
    __bool__ = __nonzero__


report_problem = ReportProblem()
152
153
def compile_file(dir):
    """
    Byte-compile the python files of a directory.

    ``report_problem`` is passed as ``quiet`` so that compile errors are
    re-raised rather than only printed.

    :param dir: directory to compile; also used as the display directory
                (the name shadows the builtin but is kept for callers
                passing it by keyword)
    :return: None on success, otherwise the exception that was raised
    """
    try:
        import compileall
        compileall.compile_dir(dir=dir, ddir=dir, force=0,
                               quiet=report_problem)
    except Exception as exc:
        # Narrowed from a bare except: report failures to the caller, but
        # let KeyboardInterrupt/SystemExit propagate.
        return exc
    return None
165
166
def _find_models():
    """
    Load every custom (plug-in) model file from the user's plug-in directory.

    Each ``*.py`` file other than ``__init__.py`` is loaded with
    load_custom_model. If a model's name does not already contain
    PLUGIN_NAME_BASE, that prefix is added to it. Files that fail to load are
    reported to the plug-in log and skipped.

    :return: dict mapping model name to loaded model; empty if the plug-in
             directory is not a directory
    """
    plugin_dir = find_plugins_dir()
    if not os.path.isdir(plugin_dir):
        logger.warning("SasView couldn't locate Model plugin folder %r."
                       % plugin_dir)
        return {}

    plugin_log("looking for models in: %s" % str(plugin_dir))
    # Plug-ins are intentionally not byte-compiled here (see compile_file).
    logger.info("plugin model dir: %s" % str(plugin_dir))

    found = {}
    for entry in os.listdir(plugin_dir):
        stem, suffix = os.path.splitext(entry)
        if suffix != '.py' or stem == '__init__':
            continue
        full_path = os.path.abspath(os.path.join(plugin_dir, entry))
        try:
            loaded = load_custom_model(full_path)
            if PLUGIN_NAME_BASE not in loaded.name:
                loaded.name = PLUGIN_NAME_BASE + loaded.name
            found[loaded.name] = loaded
        except Exception:
            details = traceback.format_exc()
            details += "\nwhile accessing model in %r" % full_path
            plugin_log(details)
            logger.warning("Failed to load plugin %r. See %s for details"
                           % (full_path, PLUGIN_LOG))

    return found
201
202
class ModelList(object):
    """
    Registry mapping a model category name to its list of models.
    """
    def __init__(self):
        """
        Start with an empty registry.
        """
        self.mydict = {}

    def set_list(self, name, mylist):
        """
        Register *mylist* under *name*, unless *name* is already registered.

        :param name: the type of the list
        :param mylist: the list to add
        """
        if name in self.mydict:
            return
        self.reset_list(name, mylist)

    def reset_list(self, name, mylist):
        """
        Register *mylist* under *name*, replacing any previous entry.

        :param name: the type of the list
        :param mylist: the list to add
        """
        self.mydict[name] = mylist

    def get_list(self):
        """
        Return the registry dictionary itself (not a copy).
        """
        return self.mydict
233
234
class ModelManagerBase:
    """
    Base class for the model manager.

    Builds ``model_dictionary`` (model name -> model) from the standard
    sasmodels models and the user's plug-in models. It also sorts the models
    into the category lists that ``get_model_list`` publishes through
    ``model_combobox``.

    NOTE(review): the category containers below are *class* attributes and
    ``_getModelList`` appends to them, so they are shared by all instances.
    This module only creates one instance (inside ModelManager).
    """
    ## external dict for models
    model_combobox = ModelList()
    ## Dictionary of form factor models
    form_factor_dict = {}
    ## dictionary of structure factor models
    struct_factor_dict = {}
    ##list of structure factors
    struct_list = []
    ##list of model allowing multiplication by a structure factor
    multiplication_factor = []
    ##list of multifunctional shapes (i.e. that have user defined number of levels
    multi_func_list = []
    ## list of added models -- currently python models found in the plugin dir.
    plugins = []
    ## Event owner (guiframe)
    event_owner = None
    # Timestamp of the last plug-in directory scan (see is_changed/findModels)
    last_time_dir_modified = 0

    def __init__(self):
        self.model_dictionary = {}
        self.stored_plugins = {}
        self._getModelList()

    def findModels(self):
        """
        Reload the plug-in models if the plug-in directory was modified.

        :return: dict of plug-in name -> model when the directory changed
                 since the last check, otherwise an empty dict
        """
        if self.is_changed():
            plugins = _find_models()
            self.last_time_dir_modified = time.time()
            return plugins
        logger.info("plugin model : %s", {})
        return {}

    def _getModelList(self):
        """
        Populate ``model_dictionary`` and the category lists.

        Standard sasmodels models are sorted into the structure-factor,
        form-factor and multiplicity lists. Non-multiplicity model names also
        go into ``model_name_list``. Plug-in models are then added.
        Structure-factor plug-ins are kept out of ``plugins``.

        :return: 0
        """
        # Regular model names only
        self.model_name_list = []

        # Build list automagically from sasmodels package
        for model in load_standard_models():
            self.model_dictionary[model.name] = model
            if model.is_structure_factor:
                self.struct_list.append(model)
            if model.is_form_factor:
                self.multiplication_factor.append(model)
            if model.is_multiplicity_model:
                self.multi_func_list.append(model)
            else:
                self.model_name_list.append(model.name)

        # Looking for plugins
        self.stored_plugins = self.findModels()
        # list() so that .remove() below also works where values() is a
        # view (Python 3).
        self.plugins = list(self.stored_plugins.values())
        for name, plug in self.stored_plugins.items():
            self.model_dictionary[name] = plug
            # TODO: Remove the getattr defaults when old style plugin models
            # are no longer supported. All sasmodels models will have
            # the required attributes.
            if getattr(plug, 'is_structure_factor', False):
                self.struct_list.append(plug)
                self.plugins.remove(plug)
            elif getattr(plug, 'is_form_factor', False):
                self.multiplication_factor.append(plug)
            if getattr(plug, 'is_multiplicity_model', False):
                self.multi_func_list.append(plug)

        self._get_multifunc_models()

        return 0

    def is_changed(self):
        """
        Check whether the plug-in directory changed since the last check.

        The directory's mtime is compared with ``last_time_dir_modified``.
        When it is newer, that timestamp is updated.

        :return: True if the directory was modified, else False
        """
        is_modified = False
        plugin_dir = find_plugins_dir()
        if os.path.isdir(plugin_dir):
            temp = os.path.getmtime(plugin_dir)
            if self.last_time_dir_modified < temp:
                is_modified = True
                self.last_time_dir_modified = temp

        return is_modified

    def update(self):
        """
        Pick up plug-in models if the plug-in directory changed.

        :return: the ``model_combobox`` dictionary when plug-ins were
                 (re)loaded, otherwise an empty dict
        """
        self.plugins = []
        new_plugins = self.findModels()
        if not new_plugins:
            return {}
        for name, plug in new_plugins.items():
            self.stored_plugins[name] = plug
            self.plugins.append(plug)
            self.model_dictionary[name] = plug
        # NOTE(review): set_list only registers CUSTOM_MODEL if that category
        # is not already present; an existing entry keeps its previous list.
        self.model_combobox.set_list(CUSTOM_MODEL, self.plugins)
        return self.model_combobox.get_list()

    def plugins_reset(self):
        """
        Reload all plug-in models and rebuild the plug-in related categories.

        Previous plug-in entries are removed from the structure-factor and
        P(Q)*S(Q) lists, and the freshly loaded plug-ins are sorted back in.
        The "Plugin Models", "Structure Factors" and "P(Q)*S(Q)" categories
        of ``model_combobox`` are then replaced.

        :return: the ``model_combobox`` dictionary of category -> model list
        """
        self.plugins = []
        self.stored_plugins = _find_models()

        # Remove all plugin structure factors and form factors. The lists
        # are filtered in place because other holders (model_combobox, the
        # class attributes) reference the same list objects.
        marker = PLUGIN_NAME_BASE.strip()
        self.struct_list[:] = [model for model in self.struct_list
                               if marker not in model.name]
        self.multiplication_factor[:] = [model for model
                                         in self.multiplication_factor
                                         if marker not in model.name]

        # Add new plugin structure factors and form factors
        for name, plug in self.stored_plugins.items():
            if plug.is_structure_factor:
                category = self.struct_list
            elif plug.is_form_factor:
                category = self.multiplication_factor
            else:
                category = None
            if category is not None:
                # Replace an existing model of the same name, if any
                existing = [model.name for model in category]
                if name in existing:
                    del category[existing.index(name)]
                category.append(plug)

            # Add references to the updated model
            if not plug.is_structure_factor:
                # Don't show S(Q) models in the 'Plugin Models' dropdown
                self.plugins.append(plug)
            self.model_dictionary[name] = plug

        self.model_combobox.reset_list("Plugin Models", self.plugins)
        self.model_combobox.reset_list("Structure Factors", self.struct_list)
        self.model_combobox.reset_list("P(Q)*S(Q)", self.multiplication_factor)

        return self.model_combobox.get_list()

    def _on_model(self, evt):
        """
        React to a model menu event

        :param evt: wx menu event
        """
        if int(evt.GetId()) in self.form_factor_dict:
            from sasmodels.sasview_model import MultiplicationModel
            self.model_dictionary[MultiplicationModel.__name__] = MultiplicationModel
            model1, model2 = self.form_factor_dict[int(evt.GetId())]
            model = MultiplicationModel(model1, model2)
        else:
            model = self.struct_factor_dict[str(evt.GetId())]()
        # NOTE(review): the constructed model is discarded and nothing is
        # returned; confirm whether callers expect it.

    def _get_multifunc_models(self):
        """
        Set ``multi_func_list`` to the plug-ins that are multiplicity models.

        Old style plug-ins without an ``is_multiplicity_model`` attribute are
        treated as non-multiplicity, matching the guards in _getModelList.
        """
        self.multi_func_list = [item for item in self.plugins
                                if getattr(item, 'is_multiplicity_model',
                                           False)]

    def get_model_list(self):
        """
        return dictionary of models for fitpanel use

        Categories that are already registered in ``model_combobox`` are left
        untouched (ModelList.set_list does not overwrite).
        """
        ## Model_list now only contains attribute lists not category list.
        ## Eventually this should be in one master list -- read in category
        ## list then pull those models that exist and get attributes then add
        ## to list ..and if model does not exist remove from list as now
        ## and update json file.
        ##
        ## -PDB   April 26, 2014
        self.model_combobox.set_list("Structure Factors", self.struct_list)
        self.model_combobox.set_list("Plugin Models", self.plugins)
        self.model_combobox.set_list("P(Q)*S(Q)", self.multiplication_factor)
        self.model_combobox.set_list("multiplication",
                                     self.multiplication_factor)
        self.model_combobox.set_list("Multi-Functions", self.multi_func_list)
        return self.model_combobox.get_list()

    def get_model_name_list(self):
        """
        return regular model name list
        """
        return self.model_name_list

    def get_model_dictionary(self):
        """
        return dictionary linking model names to objects
        """
        return self.model_dictionary
460
461
class ModelManager(object):
    """
    Facade over a single shared ModelManagerBase instance.

    The underlying manager is created once, when this class is defined. At
    the same time, the non-plug-in models are handed to
    CategoryInstaller.check_install. Every method delegates to that shared
    manager.
    """
    __modelmanager = ModelManagerBase()

    def _standard_models(manager):
        # Models from model_dictionary that are not plug-ins. This is a helper
        # because a comprehension written directly in the class body cannot
        # see class-level names (NameError on Python 3).
        return [model for name, model in manager.model_dictionary.items()
                if name not in manager.stored_plugins]

    cat_model_list = _standard_models(__modelmanager)
    del _standard_models

    CategoryInstaller.check_install(model_list=cat_model_list)

    def findModels(self):
        return self.__modelmanager.findModels()

    def _getModelList(self):
        return self.__modelmanager._getModelList()

    def is_changed(self):
        return self.__modelmanager.is_changed()

    def update(self):
        return self.__modelmanager.update()

    def plugins_reset(self):
        return self.__modelmanager.plugins_reset()

    def populate_menu(self, modelmenu, event_owner):
        # NOTE(review): ModelManagerBase defines no populate_menu, so this
        # raises AttributeError if called; kept for interface compatibility.
        return self.__modelmanager.populate_menu(modelmenu, event_owner)

    def _on_model(self, evt):
        return self.__modelmanager._on_model(evt)

    def _get_multifunc_models(self):
        return self.__modelmanager._get_multifunc_models()

    def get_model_list(self):
        return self.__modelmanager.get_model_list()

    def get_model_name_list(self):
        return self.__modelmanager.get_model_name_list()

    def get_model_dictionary(self):
        return self.__modelmanager.get_model_dictionary()
Note: See TracBrowser for help on using the repository browser.