source: sasview/src/sas/sascalc/fit/models.py @ 8225f33

unittest-saveload
Last change on this file since 8225f33 was b963b20, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

pull config out of sas.sasgui so it can be used without reference to wx

  • Property mode set to 100644
File size: 10.9 KB
RevLine 
[f0d720b]1"""
2    Utilities to manage models
3"""
[a1b8fee]4from __future__ import print_function
5
[f0d720b]6import os
7import sys
8import time
[7673ecd]9import datetime
[f0d720b]10import logging
[d3b0c77]11import traceback
[f0d720b]12import py_compile
13import shutil
[65f3930]14
15from sasmodels.sasview_model import load_custom_model, load_standard_models
[d3b0c77]16
[b963b20]17from sas import get_user_dir
[65f3930]18
[f0d720b]19# Explicitly import from the pluginmodel module so that py2exe
20# places it in the distribution. The Model1DPlugin class is used
21# as the base class of plug-in models.
[65f3930]22from .pluginmodel import Model1DPlugin
23
# Module logger; detailed plugin-loading failures go to PLUGIN_LOG instead.
logger = logging.getLogger(__name__)


# Directory name (below the per-user SasView directory) holding the plugin
# models; the plugin log file lives in it as well.
PLUGIN_DIR = 'plugin_models'
# Text log that records plugin loading problems, one timestamped entry per line.
PLUGIN_LOG = os.path.join(
    get_user_dir(),
    PLUGIN_DIR,
    'plugins.log',
)
# Prefix prepended to the name of every plugin model that lacks it.
PLUGIN_NAME_BASE = '[plug-in] '
[f0d720b]30
31
def plugin_log(message):
    """
    Append a timestamped *message* to the plugin log file PLUGIN_LOG.

    The directory containing the log file is created if it is missing, so
    that recording a plugin problem never fails merely because the plugin
    directory has not been set up yet.

    :param message: text to record; it is written on its own line,
        prefixed by the local time as ``YYYY-MM-DD HH:MM:SS: ``.
    """
    log_dir = os.path.dirname(PLUGIN_LOG)
    if log_dir and not os.path.isdir(log_dir):
        try:
            os.makedirs(log_dir)
        except OSError:
            # Another process may have created it concurrently; only
            # propagate the error if the directory is still missing.
            if not os.path.isdir(log_dir):
                raise
    stamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    # 'with' guarantees the file is closed even if the write fails.
    with open(PLUGIN_LOG, 'a') as out:
        out.write("%s: %s\n" % (stamp, message))
41
42
def _check_plugin(model, name):
    """
    Validate a plugin model class before adding it to the plugin list.

    The class must subclass Model1DPlugin, be named ``Model``, be
    constructible with no arguments, and have a ``function`` attribute that
    can be called without arguments and without raising.  Each rejection
    reason is written to the plugin log.

    :param model: class model to add into the plugin list
    :param name: name of the plugin module (used only in log messages)

    :return: *model* if it is a valid plugin, otherwise None
    """
    # Check if the plugin is of type Model1DPlugin
    if not issubclass(model, Model1DPlugin):
        plugin_log("Plugin %s must be of type Model1DPlugin \n" % str(name))
        return None
    if model.__name__ != "Model":
        plugin_log("Plugin %s class name must be Model \n" % str(name))
        return None
    try:
        new_instance = model()
    except Exception as exc:
        # sys.exc_type no longer exists in Python 3 (and is deprecated in
        # Python 2); report the type of the caught exception instead.
        msg = "Plugin %s error in __init__ \n\t: %s %s\n" % (
            str(name), str(type(exc)), exc)
        plugin_log(msg)
        return None

    if not hasattr(new_instance, "function"):
        plugin_log("Plugin  %s needs a method called function \n" % str(name))
        return None
    try:
        # Only checks that the call succeeds; the result is not used.
        new_instance.function()
    except Exception as exc:
        msg = "Plugin %s: error writing function \n\t :%s %s\n " % (
            str(name), str(type(exc)), exc)
        plugin_log(msg)
        return None
    return model
84
85
def find_plugins_dir():
    """
    Find path of the plugins directory, creating it if it does not exist.

    The plugin directory is PLUGIN_DIR inside the per-user directory
    returned by ``get_user_dir()``, the same root used to build PLUGIN_LOG,
    so plugin models and the plugin log live in the same directory.

    :return: path to the plugin directory
    """
    path = os.path.join(get_user_dir(), PLUGIN_DIR)

    # TODO: trigger initialization of plugins dir from installer or startup
    # If the plugin directory doesn't exist, create it
    if not os.path.isdir(path):
        try:
            os.makedirs(path)
        except OSError:
            # Tolerate a concurrent creation by another process; re-raise
            # anything else (e.g. permission denied).
            if not os.path.isdir(path):
                raise
    # TODO: should we be checking for new default models every time?
    # TODO: restore support for default plugins by calling
    # initialize_plugins_dir(path) here.
    return path
101
102
def initialize_plugins_dir(path):
    """
    Copy default plugin model files into the user plugin directory *path*.

    Starting from the directory containing this module, walk up at most 12
    levels (stopping early at the filesystem root) looking for a PLUGIN_DIR
    directory of default plugins.  If none is found an error is logged and
    nothing is copied.  Otherwise every regular file in it except
    ``__init__.py`` and ``*.pyc`` files is copied into *path*.  Files that
    already exist in *path* are never replaced, even if the default version
    has been updated.

    :param path: destination plugin directory (expected to exist already)
    """
    # TODO: There are no default plugins
    # TODO: Default plugins directory is in sasgui, but models.py is in sascalc
    # TODO: Move default plugins beside sample data files
    # TODO: Should not look for defaults above the root of the sasview install

    # Walk up the tree looking for default plugin_models directory
    base = os.path.abspath(os.path.dirname(__file__))
    default_plugins_path = None
    for _ in range(12):
        candidate = os.path.join(base, PLUGIN_DIR)
        if os.path.isdir(candidate):
            default_plugins_path = candidate
            break
        parent = os.path.dirname(base)
        if parent == base:
            # Reached the filesystem root; going "up" again changes nothing.
            break
        base = parent
    if default_plugins_path is None:
        logger.error("default plugins directory not found")
        return

    # Copy files from default plugins to the .sasview directory
    # This may include c files, depending on the example.
    for filename in os.listdir(default_plugins_path):
        # skip __init__.py and all pyc files
        if filename == "__init__.py" or filename.endswith('.pyc'):
            continue
        source = os.path.join(default_plugins_path, filename)
        target = os.path.join(path, filename)
        # Subdirectories are ignored and user copies are never overwritten.
        if os.path.isfile(source) and not os.path.isfile(target):
            shutil.copy(source, target)
131
132
class ReportProblem(object):
    """
    Truth-value hook passed as ``quiet`` to ``compileall.compile_dir``.

    When truth-tested while a ``py_compile.PyCompileError`` is being handled
    (as compileall may do when reporting a compile failure), the error is
    printed and re-raised so that it is not silently swallowed.  In every
    other situation the object is simply truthy.
    """
    def __bool__(self):
        exc_type, exc_value = sys.exc_info()[:2]
        if exc_type is not None and issubclass(exc_type,
                                               py_compile.PyCompileError):
            print("Problem with", repr(exc_value))
            # A bare ``raise`` re-raises the exception currently being
            # handled, with its original traceback, on Python 2 and 3 alike
            # (``raise t, v, tb`` is a SyntaxError on Python 3).
            raise
        # __bool__ must return a real bool on Python 3.
        return True

    # Python 2 name of the truth-value hook.
    __nonzero__ = __bool__

report_problem = ReportProblem()
145
146
def compile_file(dir):
    """
    Byte-compile the Python files in directory *dir* using ``compileall``.

    ``report_problem`` is passed as compileall's ``quiet`` flag.  Any
    exception raised while importing compileall or compiling is caught and
    returned instead of being propagated.

    :param dir: directory to compile; also passed as ``ddir``, the
        directory name recorded in the compiled files.
    :return: the exception instance if compilation raised, otherwise None
    """
    failure = None
    try:
        import compileall
        compileall.compile_dir(
            dir=dir,
            ddir=dir,
            force=0,
            quiet=report_problem,
        )
    except Exception as exc:
        failure = exc
    return failure
158
159
def find_plugin_models():
    """
    Load every custom model found in the user plugin directory.

    Each ``*.py`` file other than ``__init__.py`` is loaded with
    ``load_custom_model``, and the model name is given the PLUGIN_NAME_BASE
    prefix if it does not already have it.  Files that fail to load are
    reported in the plugin log (with the traceback) and as a logger warning,
    and are otherwise skipped.

    Files are processed in sorted order so the result is deterministic: if
    two files define a model with the same name, the later file (by file
    name) wins, and the collision is recorded in the plugin log.

    :return: dict mapping the (prefixed) model name to the model; empty if
        the plugin directory cannot be found.
    """
    plugins_dir = find_plugins_dir()
    if not os.path.isdir(plugins_dir):
        logger.warning("SasView couldn't locate Model plugin folder %r.",
                       plugins_dir)
        return {}

    plugin_log("looking for models in: %s" % plugins_dir)
    logger.info("plugin model dir: %s", plugins_dir)

    plugins = {}
    # os.listdir order is arbitrary; sort so name collisions resolve the
    # same way on every run and platform.
    for filename in sorted(os.listdir(plugins_dir)):
        name, ext = os.path.splitext(filename)
        if ext != '.py' or name == '__init__':
            continue
        path = os.path.abspath(os.path.join(plugins_dir, filename))
        try:
            model = load_custom_model(path)
            # TODO: add [plug-in] tag to model name in sasview_model
            if not model.name.startswith(PLUGIN_NAME_BASE):
                model.name = PLUGIN_NAME_BASE + model.name
        except Exception:
            msg = traceback.format_exc()
            msg += "\nwhile accessing model in %r" % path
            plugin_log(msg)
            logger.warning("Failed to load plugin %r. See %s for details",
                           path, PLUGIN_LOG)
            continue
        if model.name in plugins:
            plugin_log("plugin model %r in %r replaces an earlier plugin "
                       "with the same name" % (model.name, path))
        plugins[model.name] = model

    return plugins
195
196
[ba8d326]197class ModelManagerBase(object):
[f0d720b]198    """
[dcdca68]199    Base class for the model manager
[f0d720b]200    """
[277257f]201    #: mutable dictionary of models, continually updated to reflect the
202    #: current set of plugins
203    model_dictionary = None  # type: Dict[str, Model]
204    #: constant list of standard models
205    standard_models = None  # type: Dict[str, Model]
206    #: list of plugin models reset each time the plugin directory is queried
207    plugin_models = None  # type: Dict[str, Model]
208    #: timestamp on the plugin directory at the last plugin update
209    last_time_dir_modified = 0  # type: int
[f0d720b]210
211    def __init__(self):
[277257f]212        # the model dictionary is allocated at the start and updated to
213        # reflect the current list of models.  Be sure to clear it rather
214        # than reassign to it.
[f0d720b]215        self.model_dictionary = {}
216
217        #Build list automagically from sasmodels package
[277257f]218        self.standard_models = {model.name: model
219                                for model in load_standard_models()}
220        # Look for plugins
221        self.plugins_reset()
[f0d720b]222
[277257f]223    def _is_plugin_dir_changed(self):
[f0d720b]224        """
225        check the last time the plugin dir has changed and return true
[dcdca68]226        is the directory was modified else return false
[f0d720b]227        """
228        is_modified = False
229        plugin_dir = find_plugins_dir()
230        if os.path.isdir(plugin_dir):
[277257f]231            mod_time = os.path.getmtime(plugin_dir)
232            if  self.last_time_dir_modified != mod_time:
[f0d720b]233                is_modified = True
[277257f]234                self.last_time_dir_modified = mod_time
[f0d720b]235
236        return is_modified
237
[277257f]238    def composable_models(self):
239        """
240        return list of standard models that can be used in sum/multiply
241        """
242        # TODO: should scan plugin models in addition to standard models
243        # and update model_editor so that it doesn't add plugins to the list
244        return [model.name for model in self.standard_models.values()
245                if not model.is_multiplicity_model]
246
247    def plugins_update(self):
[f0d720b]248        """
249        return a dictionary of model if
250        new models were added else return empty dictionary
251        """
[277257f]252        return self.plugins_reset()
253        #if self._is_plugin_dir_changed():
254        #    return self.plugins_reset()
255        #else:
256        #    return {}
[f0d720b]257
[f66d9d1]258    def plugins_reset(self):
[f0d720b]259        """
260        return a dictionary of model
261        """
[277257f]262        self.plugin_models = find_plugin_models()
263        self.model_dictionary.clear()
264        self.model_dictionary.update(self.standard_models)
265        self.model_dictionary.update(self.plugin_models)
266        return self.get_model_list()
[f0d720b]267
268    def get_model_list(self):
269        """
[277257f]270        return dictionary of classified models
271
272        *Structure Factors* are the structure factor models
273        *Multi-Functions* are the multiplicity models
274        *Plugin Models* are the plugin models
[f0d720b]275
[277257f]276        Note that a model can be both a plugin and a structure factor or
277        multiplicity model.
[f0d720b]278        """
279        ## Model_list now only contains attribute lists not category list.
280        ## Eventually this should be in one master list -- read in category
281        ## list then pull those models that exist and get attributes then add
282        ## to list ..and if model does not exist remove from list as now
283        ## and update json file.
284        ##
285        ## -PDB   April 26, 2014
286
287
[277257f]288        # Classify models
289        structure_factors = []
[69363c7]290        form_factors = []
[277257f]291        multiplicity_models = []
292        for model in self.model_dictionary.values():
293            # Old style models don't have is_structure_factor attribute
294            if getattr(model, 'is_structure_factor', False):
295                structure_factors.append(model)
[69363c7]296            if getattr(model, 'is_form_factor', False):
297                form_factors.append(model)
[277257f]298            if model.is_multiplicity_model:
299                multiplicity_models.append(model)
300        plugin_models = list(self.plugin_models.values())
301
302        return {
303            "Structure Factors": structure_factors,
[69363c7]304            "Form Factors": form_factors,
[277257f]305            "Plugin Models": plugin_models,
306            "Multi-Functions": multiplicity_models,
307        }
[f0d720b]308
309
310class ModelManager(object):
311    """
[277257f]312    manage the list of available models
[f0d720b]313    """
[65f3930]314    base = None  # type: ModelManagerBase()
315
316    def __init__(self):
317        if ModelManager.base is None:
[277257f]318            ModelManager.base = ModelManagerBase()
[65f3930]319
320    def cat_model_list(self):
[277257f]321        return list(self.base.standard_models.values())
[f0d720b]322
323    def update(self):
[277257f]324        return self.base.plugins_update()
[f0d720b]325
[f66d9d1]326    def plugins_reset(self):
[65f3930]327        return self.base.plugins_reset()
[f0d720b]328
329    def get_model_list(self):
[65f3930]330        return self.base.get_model_list()
[f0d720b]331
[277257f]332    def composable_models(self):
333        return self.base.composable_models()
[f0d720b]334
335    def get_model_dictionary(self):
[277257f]336        return self.base.model_dictionary
Note: See TracBrowser for help on using the repository browser.