source: sasview/src/sas/sascalc/fit/models.py @ dbfd307

magnetic_scatt, release-4.2.2, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1249
Last change on this file since dbfd307 was e090ba90, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

remove errors and warnings from py37 tests of sascalc

  • Property mode set to 100644
File size: 10.8 KB
Rev Line
[f0d720b]1"""
2    Utilities to manage models
3"""
[a1b8fee]4from __future__ import print_function
5
[f0d720b]6import os
7import sys
8import time
[7673ecd]9import datetime
[f0d720b]10import logging
[d3b0c77]11import traceback
[f0d720b]12import py_compile
13import shutil
[65f3930]14
[e090ba90]15from six import reraise
16
[65f3930]17from sasmodels.sasview_model import load_custom_model, load_standard_models
[d3b0c77]18
[b963b20]19from sas import get_user_dir
[65f3930]20
[f0d720b]21# Explicitly import from the pluginmodel module so that py2exe
22# places it in the distribution. The Model1DPlugin class is used
23# as the base class of plug-in models.
[65f3930]24from .pluginmodel import Model1DPlugin
25
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Name of the plug-in model sub-directory.
PLUGIN_DIR = 'plugin_models'
# Log file appended to by plugin_log(): <user dir>/plugin_models/plugins.log,
# where the user dir comes from sas.get_user_dir().
PLUGIN_LOG = os.path.join(get_user_dir(), PLUGIN_DIR, "plugins.log")
# Prefix that find_plugin_models() prepends to plug-in model names.
PLUGIN_NAME_BASE = '[plug-in] '
[f0d720b]32
33
def plugin_log(message):
    """
    Append a timestamped message to the plug-in log file.

    The log file is PLUGIN_LOG (``plugins.log`` in the plug-in directory
    under the SasView user directory).  Each call appends one line of the
    form ``YYYY-MM-DD HH:MM:SS: <message>`` using local time.

    :param message: text to record; it is written as-is.
    """
    stamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    # Context manager so the file is closed even if the write fails.
    with open(PLUGIN_LOG, 'a') as out:
        out.write("%s: %s\n" % (stamp, message))
43
44
def _check_plugin(model, name):
    """
    Validate a candidate plug-in model class before adding it to the list.

    A valid plug-in is a class that derives from Model1DPlugin, is named
    ``Model``, can be instantiated with no arguments, and has a ``function``
    method that can be called with no arguments.  Each failure is recorded
    with plugin_log() and makes the check return None.

    :param model: class model to add into the plugin list
    :param name: name of the module plugin (used in log messages)

    :return model: model if valid model or None if not valid
    """
    import inspect

    # issubclass() raises TypeError for non-class objects; reject those the
    # same way as any other invalid plug-in.  inspect.isclass also accepts
    # Python 2 old-style classes.
    if not (inspect.isclass(model) and issubclass(model, Model1DPlugin)):
        msg = "Plugin %s must be of type Model1DPlugin \n" % str(name)
        plugin_log(msg)
        return None
    if model.__name__ != "Model":
        msg = "Plugin %s class name must be Model \n" % str(name)
        plugin_log(msg)
        return None
    try:
        new_instance = model()
    except Exception as exc:
        msg = ("Plugin %s error in __init__ \n\t: %s %s\n"
               % (name, type(exc), exc))
        plugin_log(msg)
        return None

    if not hasattr(new_instance, "function"):
        msg = "Plugin  %s needs a method called function \n" % str(name)
        plugin_log(msg)
        return None
    try:
        # Only checks that function() runs; its return value is not used.
        new_instance.function()
    except Exception as exc:
        msg = "Plugin %s: error writing function \n\t :%s %s\n " % \
                (str(name), str(type(exc)), exc)
        plugin_log(msg)
        return None
    return model
85
86
def find_plugins_dir():
    """
    Return the path of the plug-in directory, creating it if needed.

    The directory is ``.sasview/<PLUGIN_DIR>`` inside the user's home
    directory.

    :return: path of the plug-in directory (guaranteed to exist on return)
    """
    path = os.path.join(os.path.expanduser("~"), '.sasview', PLUGIN_DIR)

    # TODO: trigger initialization of plugins dir from installer or startup
    # Create the directory if it is missing.  Another process may create it
    # between the isdir() check and makedirs(), so only re-raise if it still
    # does not exist.  (makedirs(exist_ok=True) is avoided because the file
    # appears to still support Python 2 -- it uses six and __future__.)
    if not os.path.isdir(path):
        try:
            os.makedirs(path)
        except OSError:
            if not os.path.isdir(path):
                raise
    # TODO: should we be checking for new default models every time?
    # TODO: restore support for default plugins by calling
    # initialize_plugins_dir(path) here.
    return path
102
103
def initialize_plugins_dir(path):
    """
    Seed *path* with the default plug-in files, if a default set is found.

    A directory named PLUGIN_DIR is looked for in this module's directory
    and then in each parent directory, checking at most twelve directories
    in total.  If none is found, an error is logged and nothing is copied.
    Otherwise every regular file in it is copied into *path*, except
    ``__init__.py`` and ``*.pyc`` files.  Files that already exist in *path*
    are never overwritten, even if the defaults have been updated.

    :param path: destination plug-in directory
    """
    # TODO: There are no default plugins
    # TODO: Default plugins directory is in sasgui, but models.py is in sascalc
    # TODO: Move default plugins beside sample data files
    # TODO: Should not look for defaults above the root of the sasview install

    search_dir = os.path.abspath(os.path.dirname(__file__))
    default_plugins_path = None
    for _level in range(12):
        candidate = os.path.join(search_dir, PLUGIN_DIR)
        if os.path.isdir(candidate):
            default_plugins_path = candidate
            break
        search_dir = os.path.dirname(search_dir)

    if default_plugins_path is None:
        logger.error("default plugins directory not found")
        return

    # The copied files may include C sources, depending on the example.
    for filename in os.listdir(default_plugins_path):
        is_skipped = filename == "__init__.py" or filename.endswith('.pyc')
        source = os.path.join(default_plugins_path, filename)
        target = os.path.join(path, filename)
        if is_skipped or not os.path.isfile(source) or os.path.isfile(target):
            continue
        shutil.copy(source, target)
132
133
class ReportProblem(object):
    """
    Truth-value hook passed as ``quiet`` to compileall.compile_dir().

    When its truth value is tested while a py_compile.PyCompileError is
    being handled, it prints the problem and re-raises that error, so the
    error escapes compileall.  In every other case it is simply truthy.
    """
    def __nonzero__(self):
        exc_type, exc_value, _ = sys.exc_info()
        if exc_type is not None and issubclass(exc_type,
                                               py_compile.PyCompileError):
            print("Problem with", repr(exc_value))
            # A bare raise re-raises the exception currently being handled,
            # with its original traceback, on both Python 2 and Python 3.
            raise
        return True

    # Python 3 consults __bool__, not __nonzero__; without this alias the
    # hook above is never called there.
    # NOTE(review): newer compileall versions also compare ``quiet``
    # numerically (e.g. ``quiet < 2``), which this object does not support.
    # Confirm before relying on compile_file() under Python 3.
    __bool__ = __nonzero__


report_problem = ReportProblem()
146
147
def compile_file(dir):
    """
    Byte-compile the Python files under a directory.

    Delegates to compileall.compile_dir(), passing report_problem as
    ``quiet`` so that compilation errors are reported and re-raised.

    :param dir: directory to compile; also used as the ``ddir`` recorded in
        the compiled files.
    :return: the exception raised while importing compileall or compiling,
        or None if no exception was raised.
    """
    # NOTE(review): under Python 3, compileall may compare ``quiet``
    # numerically, which report_problem does not support; confirm this
    # still works there before re-enabling callers.
    error = None
    try:
        import compileall
        compileall.compile_dir(dir=dir, ddir=dir, force=0,
                               quiet=report_problem)
    except Exception as exc:
        error = exc
    return error
159
160
def find_plugin_models():
    """
    Load the custom (plug-in) models found in the plug-in directory.

    Every ``*.py`` file other than ``__init__.py`` in the directory returned
    by find_plugins_dir() is loaded with load_custom_model().  The
    PLUGIN_NAME_BASE prefix is added to each model's name unless it is
    already there.  A file that fails to load is skipped: its traceback is
    written with plugin_log() and a warning is logged.

    :return: dict mapping (prefixed) model name to model; empty if the
        plug-in directory does not exist.
    """
    plugins_dir = find_plugins_dir()
    if not os.path.isdir(plugins_dir):
        msg = "SasView couldn't locate Model plugin folder %r." % plugins_dir
        logger.warning(msg)
        return {}

    plugin_log("looking for models in: %s" % plugins_dir)
    logger.info("plugin model dir: %s", plugins_dir)

    plugins = {}
    # Sort the listing so the load order -- and therefore which file wins
    # when two plug-ins share a model name -- does not depend on the order
    # in which the file system happens to return directory entries.
    for filename in sorted(os.listdir(plugins_dir)):
        name, ext = os.path.splitext(filename)
        if ext != '.py' or name == '__init__':
            continue
        path = os.path.abspath(os.path.join(plugins_dir, filename))
        try:
            model = load_custom_model(path)
            # TODO: add [plug-in] tag to model name in sasview_model
            if not model.name.startswith(PLUGIN_NAME_BASE):
                model.name = PLUGIN_NAME_BASE + model.name
            plugins[model.name] = model
        except Exception:
            msg = traceback.format_exc()
            msg += "\nwhile accessing model in %r" % path
            plugin_log(msg)
            logger.warning("Failed to load plugin %r. See %s for details",
                           path, PLUGIN_LOG)

    return plugins
196
197
class ModelManagerBase(object):
    """
    Base class for the model manager.

    Keeps the standard sasmodels models and the plug-in models from the
    plug-in directory, merged into one name -> model dictionary that is
    refreshed in place whenever the plug-ins are rescanned.
    """
    #: mutable dictionary of models, continually updated to reflect the
    #: current set of plugins
    model_dictionary = None  # type: Dict[str, Model]
    #: standard models keyed by name; built once in __init__
    standard_models = None  # type: Dict[str, Model]
    #: plugin models keyed by name, reset each time the plugin directory
    #: is queried
    plugin_models = None  # type: Dict[str, Model]
    #: modification time of the plugin directory at the last check
    #: (os.path.getmtime returns a float)
    last_time_dir_modified = 0  # type: float

    def __init__(self):
        # The model dictionary is allocated once and then updated in place
        # to reflect the current list of models.  Clear it rather than
        # reassigning it, because references to it are handed out (e.g. by
        # ModelManager.get_model_dictionary).
        self.model_dictionary = {}

        # Build the list automatically from the sasmodels package.
        self.standard_models = {model.name: model
                                for model in load_standard_models()}
        # Look for plugins.
        self.plugins_reset()

    def _is_plugin_dir_changed(self):
        """
        Return True if the plug-in directory's modification time differs
        from the one recorded at the previous check, and record the new one.

        Returns False if the directory does not exist.
        """
        plugin_dir = find_plugins_dir()
        if not os.path.isdir(plugin_dir):
            return False
        mod_time = os.path.getmtime(plugin_dir)
        if self.last_time_dir_modified == mod_time:
            return False
        self.last_time_dir_modified = mod_time
        return True

    def composable_models(self):
        """
        Return the names of the standard models that can be used in
        sum/multiply models, i.e. those that are not multiplicity models.
        """
        # TODO: should scan plugin models in addition to standard models
        # and update model_editor so that it doesn't add plugins to the list
        # Old style models may lack the attribute; treat missing as False.
        return [model.name for model in self.standard_models.values()
                if not getattr(model, 'is_multiplicity_model', False)]

    def plugins_update(self):
        """
        Rescan the plug-in directory and return the classified model list.

        Plug-ins are currently reloaded on every call; the directory
        modification time is not consulted.
        """
        # TODO: consider reloading only when self._is_plugin_dir_changed()
        # reports a change (returning {} otherwise); that check is disabled.
        return self.plugins_reset()

    def plugins_reset(self):
        """
        Reload the plug-in models and rebuild model_dictionary in place.

        :return: the classified model dictionary from get_model_list()
        """
        self.plugin_models = find_plugin_models()
        self.model_dictionary.clear()
        self.model_dictionary.update(self.standard_models)
        self.model_dictionary.update(self.plugin_models)
        return self.get_model_list()

    def get_model_list(self):
        """
        return dictionary of classified models

        *Structure Factors* are the structure factor models
        *Form Factors* are the form factor models
        *Multi-Functions* are the multiplicity models
        *Plugin Models* are the plugin models

        Note that a model can be in more than one category, e.g. both a
        plugin and a structure factor or multiplicity model.
        """
        # Historical note (PDB, April 26, 2014): the model list now only
        # contains attribute lists, not a category list.  Eventually this
        # should be one master list: read in the category list, pull the
        # models that exist and get their attributes, remove models that no
        # longer exist, and update the json file.

        structure_factors = []
        form_factors = []
        multiplicity_models = []
        for model in self.model_dictionary.values():
            # Old style models may lack the is_* classification attributes,
            # so a missing attribute counts as False.
            if getattr(model, 'is_structure_factor', False):
                structure_factors.append(model)
            if getattr(model, 'is_form_factor', False):
                form_factors.append(model)
            if getattr(model, 'is_multiplicity_model', False):
                multiplicity_models.append(model)
        plugin_models = list(self.plugin_models.values())

        return {
            "Structure Factors": structure_factors,
            "Form Factors": form_factors,
            "Plugin Models": plugin_models,
            "Multi-Functions": multiplicity_models,
        }
[f0d720b]309
310
class ModelManager(object):
    """
    Manage the list of available models.

    All instances share one ModelManagerBase, created by the first
    ModelManager() call and stored on the class; every method delegates
    to it.
    """
    #: shared backing manager, created lazily by the first instance
    base = None  # type: ModelManagerBase

    def __init__(self):
        if ModelManager.base is None:
            ModelManager.base = ModelManagerBase()

    def cat_model_list(self):
        """Return the standard models as a list."""
        return list(self.base.standard_models.values())

    def update(self):
        """Rescan plug-ins; return the shared manager's plugins_update()."""
        return self.base.plugins_update()

    def plugins_reset(self):
        """Reload plug-ins; return the shared manager's plugins_reset()."""
        return self.base.plugins_reset()

    def get_model_list(self):
        """Return the classified model dictionary of the shared manager."""
        return self.base.get_model_list()

    def composable_models(self):
        """Return the names of models usable in sum/multiply models."""
        return self.base.composable_models()

    def get_model_dictionary(self):
        """Return the shared (live, updated-in-place) name -> model dict."""
        return self.base.model_dictionary
Note: See TracBrowser for help on using the repository browser.