source: sasview/src/sas/sascalc/fit/models.py @ 34d7b35

Branches: ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, magnetic_scatt, release-4.2.2, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Last change on this file since 34d7b35 was 9706d88, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

merge fixups

  • Property mode set to 100644
File size: 11.1 KB
Line 
1"""
2    Utilities to manage models
3"""
4from __future__ import print_function
5
6import traceback
7import os
8import sys
9import os.path
10# Time is needed by the log method
11import time
12import datetime
13import logging
14import py_compile
15import shutil
16#?? from copy import copy
17
18from sasmodels.sasview_model import load_custom_model, load_standard_models
19from sasmodels.sasview_model import MultiplicationModel
20#?? from sas.sasgui.perspectives.fitting.fitpage import CUSTOM_MODEL
21
22# Explicitly import from the pluginmodel module so that py2exe
23# places it in the distribution. The Model1DPlugin class is used
24# as the base class of plug-in models.
25from .pluginmodel import Model1DPlugin
26
logger = logging.getLogger(__name__)

# Name of the per-user plugin directory, located under ~/.sasview
PLUGIN_DIR = 'plugin_models'
# File that collects plugin discovery and loading diagnostics
PLUGIN_LOG = os.path.join(
    os.path.expanduser("~"), '.sasview', PLUGIN_DIR, "plugins.log")
# Prefix prepended to plugin model names to tell them apart from standard models
PLUGIN_NAME_BASE = '[plug-in] '
34
35
def plugin_log(message, log_path=None):
    """
    Append a timestamped message to the plugin log file.

    Each call writes one record of the form ``"YYYY-mm-dd HH:MM:SS: message"``
    followed by a newline.

    :param message: text to record
    :param log_path: file to append to; defaults to ``PLUGIN_LOG`` (the
        ``plugins.log`` file in the user's plugin directory)
    """
    if log_path is None:
        log_path = PLUGIN_LOG
    stamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    # The context manager closes the file even if the write fails.
    with open(log_path, 'a') as out:
        out.write("%s: %s\n" % (stamp, message))
45
46
def _check_plugin(model, name):
    """
    Do some checking before adding a model to the plugin list.

    The model must be a subclass of Model1DPlugin named ``Model``, must be
    constructible with no arguments, and its instances must provide a
    ``function`` method that can be called with no arguments.  The first
    failed check is written to the plugin log.

    :param model: class model to add into the plugin list
    :param name: name of the module plugin (used in log messages)

    :return model: model if valid model or None if not valid
    """
    # Check if the plugin is of type Model1DPlugin
    if not issubclass(model, Model1DPlugin):
        msg = "Plugin %s must be of type Model1DPlugin \n" % str(name)
        plugin_log(msg)
        return None
    if model.__name__ != "Model":
        msg = "Plugin %s class name must be Model \n" % str(name)
        plugin_log(msg)
        return None
    try:
        new_instance = model()
    except Exception:
        # sys.exc_type is deprecated (not thread-safe) in Python 2 and removed
        # in Python 3; sys.exc_info() is the portable equivalent.
        exc_type, exc_value = sys.exc_info()[:2]
        msg = "Plugin %s error in __init__ \n\t: %s %s\n" % (str(name),
                                                             str(exc_type),
                                                             exc_value)
        plugin_log(msg)
        return None

    if not hasattr(new_instance, "function"):
        msg = "Plugin  %s needs a method called function \n" % str(name)
        plugin_log(msg)
        return None
    try:
        # Only checks that function() runs; its return value is not used.
        new_instance.function()
    except Exception:
        exc_type, exc_value = sys.exc_info()[:2]
        msg = "Plugin %s: error writing function \n\t :%s %s\n " % \
                (str(name), str(exc_type), exc_value)
        plugin_log(msg)
        return None
    return model
88
89
def find_plugins_dir():
    """
    Return the per-user plugin directory, ``~/.sasview/<PLUGIN_DIR>``.

    The directory, including any missing parents, is created when it does
    not exist yet.

    :return: path to the plugin directory
    """
    home = os.path.expanduser("~")
    plugins_path = os.path.join(home, '.sasview', PLUGIN_DIR)

    # TODO: trigger initialization of plugins dir from installer or startup
    # TODO: should we be checking for new default models every time?
    # TODO: restore support for default plugins (initialize_plugins_dir)
    if os.path.isdir(plugins_path):
        return plugins_path
    os.makedirs(plugins_path)
    return plugins_path
105
106
def initialize_plugins_dir(path):
    """
    Copy default plugin models into the user's plugin directory.

    Starting at the directory containing this module, look at most 12 levels
    up for a ``PLUGIN_DIR`` subdirectory.  If none is found, log an error
    and return.  Otherwise copy every regular file from it into *path*,
    skipping ``__init__.py``, ``*.pyc`` files and any file that already
    exists in *path* (existing files are never overwritten, even if the
    default version changed).

    :param path: destination plugin directory
    """
    # TODO: There are no default plugins
    # TODO: Default plugins directory is in sasgui, but models.py is in sascalc
    # TODO: Move default plugins beside sample data files
    # TODO: Should not look for defaults above the root of the sasview install

    default_plugins_path = None
    search_dir = os.path.abspath(os.path.dirname(__file__))
    for _level in range(12):
        candidate = os.path.join(search_dir, PLUGIN_DIR)
        if os.path.isdir(candidate):
            default_plugins_path = candidate
            break
        search_dir = os.path.dirname(search_dir)

    if default_plugins_path is None:
        logger.error("default plugins directory not found")
        return

    # This may include c files, depending on the example.
    wanted = [entry for entry in os.listdir(default_plugins_path)
              if entry != "__init__.py" and not entry.endswith('.pyc')]
    for entry in wanted:
        source = os.path.join(default_plugins_path, entry)
        target = os.path.join(path, entry)
        if os.path.isfile(source) and not os.path.isfile(target):
            shutil.copy(source, target)
135
136
class ReportProblem(object):
    """
    Truth-value hook intended to be passed as ``quiet`` to
    :func:`compileall.compile_dir`.

    When its truth value is tested while a
    :class:`py_compile.PyCompileError` is being handled, the error is
    printed and re-raised.  Otherwise the object is simply truthy.

    NOTE(review): this presumably relies on compileall testing ``quiet``
    inside its ``except PyCompileError`` handler.  Verify against the
    compileall version in use, since some versions compare ``quiet``
    numerically instead.
    """
    def __nonzero__(self):
        exc_type, exc_value = sys.exc_info()[:2]
        if exc_type is not None and issubclass(exc_type,
                                               py_compile.PyCompileError):
            print("Problem with", repr(exc_value))
            # A bare ``raise`` re-raises the exception currently being
            # handled, with its original traceback.  Unlike the Python
            # 2-only ``raise type, value, tb`` form, it is valid in both
            # Python 2 and Python 3.
            raise
        # __bool__ must return a real bool in Python 3 (int 1 is rejected).
        return True

    # Python 3 consults __bool__ instead of __nonzero__.
    __bool__ = __nonzero__

report_problem = ReportProblem()
149
150
def compile_file(dir):
    """
    Byte-compile the Python files under a directory.

    Delegates to :func:`compileall.compile_dir`.  *dir* is also passed as
    ``ddir``, and ``report_problem`` is passed as ``quiet``.  Any exception
    is returned rather than raised.

    :param dir: directory to compile
    :return: the exception raised while compiling, or None on success
    """
    outcome = None
    try:
        import compileall
        compileall.compile_dir(dir=dir, ddir=dir, force=0,
                               quiet=report_problem)
    except Exception as failure:
        outcome = failure
    return outcome
162
163
def find_plugin_models():
    """
    Find custom models in the user's plugin directory.

    Every ``*.py`` file except ``__init__.py`` is loaded with
    ``load_custom_model``.  ``PLUGIN_NAME_BASE`` is prepended to each loaded
    model's name unless the name already starts with it.  A file that fails
    to load has its traceback written to the plugin log and is skipped.

    :return: dictionary mapping plugin model name to model class; empty if
        the plugin directory cannot be found
    """
    plugins_dir = find_plugins_dir()
    # Go through files in plug-in directory
    if not os.path.isdir(plugins_dir):
        msg = "SasView couldn't locate Model plugin folder %r." % plugins_dir
        logger.warning(msg)
        return {}

    plugin_log("looking for models in: %s" % plugins_dir)
    # compile_file(plugins_dir)  #always recompile the folder plugin
    logger.info("plugin model dir: %s", plugins_dir)

    plugins = {}
    # os.listdir order is arbitrary.  Sort it so that, when two files define
    # a model with the same name, the same file wins on every run.
    for filename in sorted(os.listdir(plugins_dir)):
        name, ext = os.path.splitext(filename)
        if ext != '.py' or name == '__init__':
            continue
        path = os.path.abspath(os.path.join(plugins_dir, filename))
        try:
            model = load_custom_model(path)
            # TODO: add [plug-in] tag to model name in sasview_model
            if not model.name.startswith(PLUGIN_NAME_BASE):
                model.name = PLUGIN_NAME_BASE + model.name
            if model.name in plugins:
                logger.warning("Plugin %r redefines model %r; "
                               "replacing the earlier definition",
                               path, model.name)
            plugins[model.name] = model
        except Exception:
            msg = traceback.format_exc()
            msg += "\nwhile accessing model in %r" % path
            plugin_log(msg)
            logger.warning("Failed to load plugin %r. See %s for details",
                           path, PLUGIN_LOG)

    return plugins
199
200
class ModelManagerBase(object):
    """
    Base class for the model manager

    Holds the standard sasmodels models, the plugin models found in the
    user's plugin directory, and a combined dictionary of both.
    """
    #: mutable dictionary of models, continually updated to reflect the
    #: current set of plugins
    model_dictionary = None  # type: Dict[str, Model]
    #: constant list of standard models
    standard_models = None  # type: Dict[str, Model]
    #: list of plugin models reset each time the plugin directory is queried
    plugin_models = None  # type: Dict[str, Model]
    #: timestamp on the plugin directory at the last plugin update
    last_time_dir_modified = 0  # type: float

    def __init__(self):
        # the model dictionary is allocated at the start and updated to
        # reflect the current list of models.  Be sure to clear it rather
        # than reassign to it.
        self.model_dictionary = {}

        #Build list automagically from sasmodels package
        self.standard_models = {model.name: model
                                for model in load_standard_models()}
        # Look for plugins
        self.plugins_reset()

    def _is_plugin_dir_changed(self):
        """
        Check whether the plugin directory was modified since the last call.

        The directory's modification time is compared with the value
        recorded on the previous call, and the new value is recorded.

        :return: True if the directory exists and its modification time
            differs from the recorded one, else False
        """
        is_modified = False
        plugin_dir = find_plugins_dir()
        if os.path.isdir(plugin_dir):
            mod_time = os.path.getmtime(plugin_dir)
            if self.last_time_dir_modified != mod_time:
                is_modified = True
                self.last_time_dir_modified = mod_time

        return is_modified

    def composable_models(self):
        """
        Return the names of standard models that can be used in sum/multiply.

        Multiplicity models are excluded.
        """
        # TODO: should scan plugin models in addition to standard models
        # and update model_editor so that it doesn't add plugins to the list
        # getattr keeps this consistent with get_model_list for models that
        # lack the classification attributes.
        return [model.name for model in self.standard_models.values()
                if not getattr(model, 'is_multiplicity_model', False)]

    def plugins_update(self):
        """
        Rescan the plugin directory and return the classified model lists.

        The plugin directory is currently rescanned on every call; the
        modification-time shortcut (``_is_plugin_dir_changed``) is disabled.

        :return: the dictionary produced by :meth:`get_model_list`
        """
        return self.plugins_reset()

    def plugins_reset(self):
        """
        Reload plugin models and rebuild ``model_dictionary`` in place.

        Standard models are inserted first and plugin models second, so a
        plugin whose name matches a standard model replaces it in the
        dictionary.

        :return: the dictionary produced by :meth:`get_model_list`
        """
        self.plugin_models = find_plugin_models()
        self.model_dictionary.clear()
        self.model_dictionary.update(self.standard_models)
        self.model_dictionary.update(self.plugin_models)
        return self.get_model_list()

    def get_model_list(self):
        """
        return dictionary of classified models

        *Structure Factors* are the structure factor models
        *Form Factors* are the form factor models
        *Multi-Functions* are the multiplicity models
        *Plugin Models* are the plugin models

        Note that a model can be both a plugin and a structure factor or
        multiplicity model.
        """
        ## Model_list now only contains attribute lists not category list.
        ## Eventually this should be in one master list -- read in category
        ## list then pull those models that exist and get attributes then add
        ## to list ..and if model does not exist remove from list as now
        ## and update json file.
        ##
        ## -PDB   April 26, 2014

        # Classify models
        structure_factors = []
        form_factors = []
        multiplicity_models = []
        for model in self.model_dictionary.values():
            # Old style models don't have the classification attributes, so
            # every flag is read with a False default.
            if getattr(model, 'is_structure_factor', False):
                structure_factors.append(model)
            if getattr(model, 'is_form_factor', False):
                form_factors.append(model)
            if getattr(model, 'is_multiplicity_model', False):
                multiplicity_models.append(model)
        plugin_models = list(self.plugin_models.values())

        return {
            "Structure Factors": structure_factors,
            "Form Factors": form_factors,
            "Plugin Models": plugin_models,
            "Multi-Functions": multiplicity_models,
        }
312
313
class ModelManager(object):
    """
    Manage the list of available models.

    All instances share one :class:`ModelManagerBase`, kept in the class
    attribute ``base``.  It is created by the first construction and reused
    by later instances.  Every method forwards to that shared object.
    """
    #: shared backing manager; None until the first instance is constructed
    base = None  # type: Optional[ModelManagerBase]

    def __init__(self):
        # Build the shared manager only if no instance has done so yet.
        if ModelManager.base is not None:
            return
        ModelManager.base = ModelManagerBase()

    def cat_model_list(self):
        """Return the standard model classes as a new list."""
        standard = self.base.standard_models
        return [model for model in standard.values()]

    def update(self):
        """Forward to ``base.plugins_update()`` and return its result."""
        shared = self.base
        return shared.plugins_update()

    def plugins_reset(self):
        """Forward to ``base.plugins_reset()`` and return its result."""
        shared = self.base
        return shared.plugins_reset()

    def get_model_list(self):
        """Forward to ``base.get_model_list()`` and return its result."""
        shared = self.base
        return shared.get_model_list()

    def composable_models(self):
        """Forward to ``base.composable_models()`` and return its result."""
        shared = self.base
        return shared.composable_models()

    def get_model_dictionary(self):
        """Return the shared manager's live ``model_dictionary`` (not a copy)."""
        shared = self.base
        return shared.model_dictionary
Note: See TracBrowser for help on using the repository browser.