source: sasview/src/sas/sasgui/perspectives/fitting/models.py @ ed03b99

ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, magnetic_scatt, release-4.2.2, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Last change on this file since ed03b99 was 914ba0a, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

merge with master

  • Property mode set to 100644
File size: 14.0 KB
Line 
1"""
2    Utilities to manage models
3"""
4from __future__ import print_function
5
6import traceback
7import os
8import sys
9import os.path
10# Time is needed by the log method
11import time
12import datetime
13import logging
14import py_compile
15import shutil
16from sasmodels.sasview_model import load_custom_model, load_standard_models
17# Explicitly import from the pluginmodel module so that py2exe
18# places it in the distribution. The Model1DPlugin class is used
19# as the base class of plug-in models.
20from sas.sasgui import get_user_dir
21from sas.sascalc.fit.pluginmodel import Model1DPlugin
22from sas.sasgui.guiframe.CategoryInstaller import CategoryInstaller
23
logger = logging.getLogger(__name__)


# Name of the plug-in model folder; used both under the per-user
# directory and next to this module (see find_plugins_dir).
PLUGIN_DIR = 'plugin_models'
# File to which plugin_log() appends plug-in loading diagnostics.
PLUGIN_LOG = os.path.join(get_user_dir(), PLUGIN_DIR, "plugins.log")
# Prefix prepended to the name of every plug-in model that _findModels loads.
PLUGIN_NAME_BASE = '[plug-in] '
30
def get_model_python_path():
    """
    Return the directory that contains this module's file.
    """
    module_dir, _ = os.path.split(__file__)
    return module_dir
36
37
def plugin_log(message):
    """
    Append a time-stamped message to the plug-in log file PLUGIN_LOG.

    The directory holding the log is created first if it is missing, so
    that reporting a plug-in problem cannot itself fail for that reason.

    :param message: text to record; written as "<timestamp>: <message>\\n"
    """
    log_dir = os.path.dirname(PLUGIN_LOG)
    if log_dir and not os.path.isdir(log_dir):
        try:
            os.makedirs(log_dir)
        except OSError:
            # Another process may have created it in the meantime.
            if not os.path.isdir(log_dir):
                raise
    stamp = datetime.datetime.fromtimestamp(time.time()).strftime(
        '%Y-%m-%d %H:%M:%S')
    # 'with' guarantees the file is closed even if the write fails.
    with open(PLUGIN_LOG, 'a') as out:
        out.write("%s: %s\n" % (stamp, message))
47
48
49def _check_plugin(model, name):
50    """
51    Do some checking before model adding plugins in the list
52
53    :param model: class model to add into the plugin list
54    :param name:name of the module plugin
55
56    :return model: model if valid model or None if not valid
57
58    """
59    #Check if the plugin is of type Model1DPlugin
60    if not issubclass(model, Model1DPlugin):
61        msg = "Plugin %s must be of type Model1DPlugin \n" % str(name)
62        plugin_log(msg)
63        return None
64    if model.__name__ != "Model":
65        msg = "Plugin %s class name must be Model \n" % str(name)
66        plugin_log(msg)
67        return None
68    try:
69        new_instance = model()
70    except:
71        msg = "Plugin %s error in __init__ \n\t: %s %s\n" % (str(name),
72                                                             str(sys.exc_type),
73                                                             sys.exc_info()[1])
74        plugin_log(msg)
75        return None
76
77    if hasattr(new_instance, "function"):
78        try:
79            value = new_instance.function()
80        except:
81            msg = "Plugin %s: error writing function \n\t :%s %s\n " % \
82                    (str(name), str(sys.exc_type), sys.exc_info()[1])
83            plugin_log(msg)
84            return None
85    else:
86        msg = "Plugin  %s needs a method called function \n" % str(name)
87        plugin_log(msg)
88        return None
89    return model
90
91
def find_plugins_dir():
    """
    Find path of the plugins directory, creating it if needed.

    The plug-in directory is PLUGIN_DIR inside the per-user directory given
    by get_user_dir(), the same base directory that PLUGIN_LOG uses.
    Example plug-in models shipped with the application are copied into
    it, unless a file with the same name is already there.

    :return: path of the user's plug-in directory
    """
    user_plugin_dir = os.path.join(get_user_dir(), PLUGIN_DIR)

    # If the plugin directory doesn't exist, create it
    if not os.path.isdir(user_plugin_dir):
        os.makedirs(user_plugin_dir)

    # Place example user models as needed
    examples_dir = _find_example_plugins_dir()
    if examples_dir is not None:
        for filename in os.listdir(examples_dir):
            name, ext = os.path.splitext(filename)
            if ext != '.py' or name == '__init__':
                continue
            source = os.path.join(examples_dir, filename)
            target = os.path.join(user_plugin_dir, filename)
            # Never overwrite a model the user may have edited.
            if os.path.isfile(source) and not os.path.isfile(target):
                shutil.copy(source, target)

    return user_plugin_dir


def _find_example_plugins_dir(max_levels=12):
    """
    Locate the folder of example plug-in models shipped with the application.

    For a source install, the folder is PLUGIN_DIR next to this module.  If
    this module's directory is not a real directory (e.g. the module lives
    inside a zip file in a frozen build), the search checks that path and
    up to max_levels - 1 of its parents for a PLUGIN_DIR folder.

    :param max_levels: number of directory levels to search
    :return: the folder's path, or None if it cannot be found
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    if os.path.isdir(module_dir):
        candidate = os.path.join(module_dir, PLUGIN_DIR)
        return candidate if os.path.isdir(candidate) else None

    # Check for data path next to exe/zip file.
    search_dir = module_dir
    for _ in range(max_levels):
        candidate = os.path.join(search_dir, PLUGIN_DIR)
        if os.path.isdir(candidate):
            return candidate
        parent = os.path.dirname(search_dir)
        if parent == search_dir:  # reached the filesystem root
            break
        search_dir = parent
    return None
136
137
class ReportProblem:
    """
    Truth value that re-raises a pending py_compile error.

    compile_file() passes an instance as the ``quiet`` argument of
    compileall.compile_dir().  When its truth value is taken while a
    py_compile.PyCompileError is being handled, the error is printed and
    re-raised; otherwise the object is simply true.
    """
    def __nonzero__(self):
        exc_type, exc_value = sys.exc_info()[:2]
        if exc_type is not None and issubclass(exc_type, py_compile.PyCompileError):
            print("Problem with", repr(exc_value))
            # A bare raise re-raises the exception currently being handled,
            # with its original traceback, in both Python 2 and Python 3.
            raise
        return True

    # Python 3 name of the truth-value hook (must return a real bool).
    __bool__ = __nonzero__

report_problem = ReportProblem()
150
151
def compile_file(dir):
    """
    Byte-compile the Python files in a directory tree.

    Recompilation is forced, and report_problem is used as compileall's
    ``quiet`` flag.

    :param dir: directory to compile
    :return: None on success, otherwise the exception that was raised
    """
    import compileall
    try:
        compileall.compile_dir(dir=dir, ddir=dir, force=1,
                               quiet=report_problem)
    except Exception as exc:
        # Only ordinary errors become a return value; KeyboardInterrupt
        # and SystemExit still propagate.
        return exc
    return None
163
164
165def _findModels(dir):
166    """
167    Find custom models
168    """
169    # List of plugin objects
170    dir = find_plugins_dir()
171    # Go through files in plug-in directory
172    if not os.path.isdir(dir):
173        msg = "SasView couldn't locate Model plugin folder %r." % dir
174        logger.warning(msg)
175        return {}
176
177    plugin_log("looking for models in: %s" % str(dir))
178    #compile_file(dir)  #always recompile the folder plugin
179    logger.info("plugin model dir: %s" % str(dir))
180
181    plugins = {}
182    for filename in os.listdir(dir):
183        name, ext = os.path.splitext(filename)
184        if ext == '.py' and not name == '__init__':
185            path = os.path.abspath(os.path.join(dir, filename))
186            try:
187                model = load_custom_model(path)
188                model.name = PLUGIN_NAME_BASE + model.name
189                plugins[model.name] = model
190            except Exception:
191                msg = traceback.format_exc()
192                msg += "\nwhile accessing model in %r" % path
193                plugin_log(msg)
194                logger.warning("Failed to load plugin %r. See %s for details"
195                                % (path, PLUGIN_LOG))
196
197    return plugins
198
199
class ModelList(object):
    """
    Collection of model lists keyed by list type (category name).
    """
    def __init__(self):
        """
        Start with no lists registered.
        """
        self.mydict = {}

    def set_list(self, name, mylist):
        """
        Register *mylist* under *name* unless that name is already taken.

        :param name: the type of the list
        :param mylist: the list to add
        """
        if name in self.mydict:
            return
        self.reset_list(name, mylist)

    def reset_list(self, name, mylist):
        """
        Register *mylist* under *name*, replacing any existing entry.

        :param name: the type of the list
        :param mylist: the list to add
        """
        self.mydict[name] = mylist

    def get_list(self):
        """
        Return the dictionary holding every registered list (not a copy).
        """
        return self.mydict
230
231
class ModelManagerBase:
    """
    Base class for the model manager.

    Collects the standard sasmodels models and the plug-in models found in
    the user's plug-in directory, and sorts them into the category lists
    returned by get_model_list().
    """
    ## Event owner (guiframe)
    event_owner = None
    ## mtime of the plug-in directory recorded by the last is_changed() call
    last_time_dir_modified = 0

    def __init__(self):
        # Containers are created per instance.  As class attributes they
        # would be shared by every instance, and _getModelList() would keep
        # appending to the same lists.
        ## external dict for models
        self.model_combobox = ModelList()
        ## Dictionary of form factor models
        self.form_factor_dict = {}
        ## dictionary of structure factor models
        self.struct_factor_dict = {}
        ## list of structure factors
        self.struct_list = []
        ## list of model allowing multiplication by a structure factor
        self.multiplication_factor = []
        ## list of multifunctional shapes (user defined number of levels)
        self.multi_func_list = []
        ## list of added models -- currently python models found in the plugin dir.
        self.plugins = []
        ## names of the standard, non-multiplicity models
        self.model_name_list = []
        self.model_dictionary = {}
        self.stored_plugins = {}
        self._getModelList()

    def findModels(self):
        """
        Scan the plug-in directory if it changed since the last check.

        :return: dict mapping plug-in model names to models; empty when the
            directory has not been modified since the previous check
        """
        if self.is_changed():
            # _findModels ignores its argument and locates the plug-in
            # directory itself.
            return _findModels(None)
        unchanged = {}
        logger.info("plugin model : %s", unchanged)
        return unchanged

    def _getModelList(self):
        """
        Build the model lists and dictionary from the standard sasmodels
        models and the plug-in models.

        :return: 0 (return value kept for backward compatibility)
        """
        # regular model names only
        self.model_name_list = []
        # Empty the category lists in place: repeated calls must not add
        # duplicates, and lists already handed to model_combobox must stay
        # the same objects.
        del self.struct_list[:]
        del self.multiplication_factor[:]
        self.multi_func_list = []

        # Build list automagically from sasmodels package
        for model in load_standard_models():
            self.model_dictionary[model.name] = model
            if model.is_structure_factor:
                self.struct_list.append(model)
            if model.is_form_factor:
                self.multiplication_factor.append(model)
            if model.is_multiplicity_model:
                self.multi_func_list.append(model)
            else:
                self.model_name_list.append(model.name)

        # Looking for plugins
        self.stored_plugins = self.findModels()
        # A real list, because update() appends to it.
        self.plugins = list(self.stored_plugins.values())
        self.model_dictionary.update(self.stored_plugins)

        # NOTE(review): this replaces multi_func_list with the plug-in
        # multiplicity models only, discarding the standard ones gathered
        # above -- confirm that this is intended.
        self._get_multifunc_models()

        return 0

    def is_changed(self):
        """
        Check whether the plug-in directory changed since the last call.

        :return: True if the directory's modification time differs from the
            one recorded previously (the new time is then recorded)
        """
        is_modified = False
        plugin_dir = find_plugins_dir()
        if os.path.isdir(plugin_dir):
            mtime = os.path.getmtime(plugin_dir)
            if self.last_time_dir_modified != mtime:
                is_modified = True
                self.last_time_dir_modified = mtime

        return is_modified

    def update(self):
        """
        Add plug-in models that appeared since the last scan.

        :return: the model category dictionary if the plug-in directory
            yielded models, else an empty dictionary
        """
        new_plugins = self.findModels()
        if not new_plugins:
            return {}
        for name, plug in new_plugins.items():
            if name not in self.stored_plugins:
                self.stored_plugins[name] = plug
                self.plugins.append(plug)
                self.model_dictionary[name] = plug
        self.model_combobox.set_list("Plugin Models", self.plugins)
        return self.model_combobox.get_list()

    def plugins_reset(self):
        """
        Reload every plug-in model from the plug-in directory.

        :return: the model category dictionary, with "Plugin Models"
            replaced by the freshly loaded plug-ins
        """
        self.plugins = []
        new_plugins = _findModels(None)
        for name, plug in new_plugins.items():
            # Reloaded plug-ins replace any stored model of the same name.
            self.stored_plugins[name] = plug
            self.plugins.append(plug)
            self.model_dictionary[name] = plug

        self.model_combobox.reset_list("Plugin Models", self.plugins)
        return self.model_combobox.get_list()

    def _on_model(self, evt):
        """
        React to a model menu event.

        :param evt: wx menu event; its id selects either a model pair in
            form_factor_dict or a model class in struct_factor_dict
        :return: the model instance built for the event
        """
        evt_id = evt.GetId()
        if int(evt_id) in self.form_factor_dict:
            from sasmodels.sasview_model import MultiplicationModel
            self.model_dictionary[MultiplicationModel.__name__] = MultiplicationModel
            model1, model2 = self.form_factor_dict[int(evt_id)]
            model = MultiplicationModel(model1, model2)
        else:
            # NOTE(review): keyed by str(id) while form_factor_dict uses
            # int(id); neither dict is filled in this file -- confirm.
            model = self.struct_factor_dict[str(evt_id)]()
        return model

    def _get_multifunc_models(self):
        """
        Set multi_func_list to the plug-in models that are multiplicity models.
        """
        self.multi_func_list = [item for item in self.plugins
                                if item.is_multiplicity_model]

    def get_model_list(self):
        """
        Return the dictionary of model category lists for fitpanel use.

        A category is registered only the first time it is seen; later calls
        return the already registered lists.
        """
        ## Model_list now only contains attribute lists not category list.
        ## Eventually this should be in one master list -- read in category
        ## list then pull those models that exist and get attributes then add
        ## to list ..and if model does not exist remove from list as now
        ## and update json file.
        ##
        ## -PDB   April 26, 2014
        self.model_combobox.set_list("Structure Factors", self.struct_list)
        self.model_combobox.set_list("Plugin Models", self.plugins)
        self.model_combobox.set_list("P(Q)*S(Q)", self.multiplication_factor)
        self.model_combobox.set_list("multiplication",
                                     self.multiplication_factor)
        self.model_combobox.set_list("Multi-Functions", self.multi_func_list)
        return self.model_combobox.get_list()

    def get_model_name_list(self):
        """
        return regular model name list
        """
        return self.model_name_list

    def get_model_dictionary(self):
        """
        return dictionary linking model names to objects
        """
        return self.model_dictionary
413
414
class ModelManager(object):
    """
    Facade over one ModelManagerBase instance shared by all ModelManagers.

    The shared manager is created when this class is defined.  At that time
    its non-plug-in models are also passed to
    CategoryInstaller.check_install().  Every method delegates to the
    shared instance.
    """
    __modelmanager = ModelManagerBase()

    def _non_plugin_models(manager):
        """Return the models of *manager* that are not plug-in models."""
        return [model for name, model in manager.model_dictionary.items()
                if name not in manager.stored_plugins]

    # A helper is needed because, in Python 3, a comprehension in a class
    # body cannot see class-level names such as __modelmanager (NameError).
    # It is deleted afterwards so it does not become a method.
    cat_model_list = _non_plugin_models(__modelmanager)
    del _non_plugin_models

    CategoryInstaller.check_install(model_list=cat_model_list)

    def findModels(self):
        """Delegate to ModelManagerBase.findModels."""
        return self.__modelmanager.findModels()

    def _getModelList(self):
        """Delegate to ModelManagerBase._getModelList."""
        return self.__modelmanager._getModelList()

    def is_changed(self):
        """Delegate to ModelManagerBase.is_changed."""
        return self.__modelmanager.is_changed()

    def update(self):
        """Delegate to ModelManagerBase.update."""
        return self.__modelmanager.update()

    def plugins_reset(self):
        """Delegate to ModelManagerBase.plugins_reset."""
        return self.__modelmanager.plugins_reset()

    def populate_menu(self, modelmenu, event_owner):
        """Delegate to the shared manager's populate_menu."""
        # NOTE(review): ModelManagerBase as defined in this file has no
        # populate_menu method, so this call raises AttributeError.
        return self.__modelmanager.populate_menu(modelmenu, event_owner)

    def _on_model(self, evt):
        """Delegate to ModelManagerBase._on_model."""
        return self.__modelmanager._on_model(evt)

    def _get_multifunc_models(self):
        """Delegate to ModelManagerBase._get_multifunc_models."""
        return self.__modelmanager._get_multifunc_models()

    def get_model_list(self):
        """Delegate to ModelManagerBase.get_model_list."""
        return self.__modelmanager.get_model_list()

    def get_model_name_list(self):
        """Delegate to ModelManagerBase.get_model_name_list."""
        return self.__modelmanager.get_model_name_list()

    def get_model_dictionary(self):
        """Delegate to ModelManagerBase.get_model_dictionary."""
        return self.__modelmanager.get_model_dictionary()
Note: See TracBrowser for help on using the repository browser.