source: sasview/src/sas/qtgui/Perspectives/Fitting/ModelUtilities.py @ 6b43c58

Last change on this file since 6b43c58 was 125c4be, checked in by Piotr Rozyczko <rozyczko@…>, 8 years ago

Allow rebuilding of the categories file and plugin_models dir

  • Property mode set to 100755
File size: 14.1 KB
Line 
1"""
2    Utilities to manage models
3"""
4from __future__ import print_function
5
6import traceback
7import os
8import sys
9import os.path
10# Time is needed by the log method
11import time
12import datetime
13import logging
14import py_compile
15import shutil
16
17# Explicitly import from the pluginmodel module so that py2exe
18# places it in the distribution. The Model1DPlugin class is used
19# as the base class of plug-in models.
20#from sas.sascalc.fit.pluginmodel import Model1DPlugin
21from sas.qtgui.Utilities.CategoryInstaller import CategoryInstaller
22from sasmodels.sasview_model import load_custom_model, load_standard_models
23
# Folder name used for plug-in models, both under ~/.sasview and (for the
# bundled examples) next to this module.
PLUGIN_DIR = 'plugin_models'
# Log file that records plug-in loading problems.
PLUGIN_LOG = os.path.join(os.path.expanduser("~"), '.sasview',
                          PLUGIN_DIR, "plugins.log")
# Prefix prepended to the name of every plug-in model that is loaded.
PLUGIN_NAME_BASE = '[plug-in] '
28
def get_model_python_path():
    """
    Return the directory that contains this module.
    """
    module_dir, _ = os.path.split(__file__)
    return module_dir
34
35
def plugin_log(message):
    """
    Append a time-stamped message to the plug-in log file.

    The file is ``PLUGIN_LOG`` (under the user's home directory). Each call
    appends one line of the form ``"YYYY-mm-dd HH:MM:SS: <message>"``.

    :param message: text to record
    """
    stamp = datetime.datetime.fromtimestamp(time.time()).strftime(
        '%Y-%m-%d %H:%M:%S')
    # The context manager closes the file even if the write raises, which
    # an explicit open()/close() pair does not guarantee.
    with open(PLUGIN_LOG, 'a') as out:
        out.write("%s: %s\n" % (stamp, message))
45
46
47def _check_plugin(model, name):
48    """
49    Do some checking before model adding plugins in the list
50
51    :param model: class model to add into the plugin list
52    :param name:name of the module plugin
53
54    :return model: model if valid model or None if not valid
55
56    """
57    #Check if the plugin is of type Model1DPlugin
58    #if not issubclass(model, Model1DPlugin):
59    #    msg = "Plugin %s must be of type Model1DPlugin \n" % str(name)
60    #    plugin_log(msg)
61    #    return None
62    if model.__name__ != "Model":
63        msg = "Plugin %s class name must be Model \n" % str(name)
64        plugin_log(msg)
65        return None
66    try:
67        new_instance = model()
68    except:
69        msg = "Plugin %s error in __init__ \n\t: %s %s\n" % (str(name),
70                                                             str(sys.exc_type),
71                                                             sys.exc_info()[1])
72        plugin_log(msg)
73        return None
74
75    if hasattr(new_instance, "function"):
76        try:
77            value = new_instance.function()
78        except:
79            msg = "Plugin %s: error writing function \n\t :%s %s\n " % \
80                    (str(name), str(sys.exc_type), sys.exc_info()[1])
81            plugin_log(msg)
82            return None
83    else:
84        msg = "Plugin  %s needs a method called function \n" % str(name)
85        plugin_log(msg)
86        return None
87    return model
88
89
def find_plugins_dir():
    """
    Find the user's plug-in directory, creating it if needed.

    The directory is ``~/.sasview/<PLUGIN_DIR>``. Example plug-in models
    shipped with the application (a ``PLUGIN_DIR`` folder next to this
    module, or next to the exe/zip file that contains it) are copied into
    it unless a file with the same name is already there.

    :return: path of the user's plug-in directory
    :raises RuntimeError: if this module is not in a real directory and no
        bundled plug-in folder is found in the directories above it
    """
    user_dir = os.path.join(os.path.expanduser("~"), '.sasview', PLUGIN_DIR)

    # If the plugin directory doesn't exist, create it. Tolerate it being
    # created by someone else between the check and makedirs().
    if not os.path.isdir(user_dir):
        try:
            os.makedirs(user_dir)
        except OSError:
            if not os.path.isdir(user_dir):
                raise

    # Place example user models as needed (same *.py / __init__ rule that
    # _find_models uses when loading them).
    example_dir = _find_bundled_plugins_dir()
    if os.path.isdir(example_dir):
        for filename in os.listdir(example_dir):
            stem, ext = os.path.splitext(filename)
            if ext != '.py' or stem == '__init__':
                continue
            src_path = os.path.join(example_dir, filename)
            if os.path.isfile(src_path) and \
                    not os.path.isfile(os.path.join(user_dir, filename)):
                shutil.copy(src_path, user_dir)

    return user_dir


def _find_bundled_plugins_dir(max_levels=10):
    """
    Return the path of the plug-in folder shipped with the application.

    When running from source the folder is expected next to this module
    (the returned path may not exist). Otherwise, e.g. when this module is
    inside an exe/zip file, the module's directory and up to *max_levels*
    of its ancestors are searched for an existing ``PLUGIN_DIR`` folder.

    :param max_levels: number of parent directories to search
    :return: path of the bundled plug-in folder
    :raises RuntimeError: if the search finds no such folder
    """
    module_dir = os.path.dirname(__file__)
    # For source
    if os.path.isdir(module_dir):
        return os.path.join(module_dir, PLUGIN_DIR)

    # Check for data path next to exe/zip file.
    search_dir = module_dir
    for _ in range(max_levels + 1):
        candidate = os.path.join(search_dir, PLUGIN_DIR)
        if os.path.isdir(candidate):
            return candidate
        search_dir = os.path.dirname(search_dir)
    raise RuntimeError("Could not locate the bundled %r folder near %r"
                       % (PLUGIN_DIR, module_dir))
134
135
class ReportProblem:
    """
    Truth-value hook passed as ``quiet`` to ``compileall.compile_dir``.

    Testing an instance for truth while a ``py_compile.PyCompileError`` is
    being handled prints the error and re-raises it. Otherwise the instance
    is true.

    NOTE(review): Python 3's compileall may compare ``quiet`` numerically
    (e.g. ``quiet < 2``), which this object does not support -- confirm
    before using it there.
    """
    def __nonzero__(self):
        exc_type, exc_value = sys.exc_info()[:2]
        if exc_type is not None and \
                issubclass(exc_type, py_compile.PyCompileError):
            print("Problem with", repr(exc_value))
            # A bare ``raise`` re-raises the exception currently being
            # handled, with its original traceback. It is valid on Python 2
            # and 3, unlike the Python-2-only ``raise type, value, tb``.
            raise
        # Must be a real bool: Python 3 rejects an int from __bool__.
        return True

    # Python 3 consults __bool__ instead of __nonzero__.
    __bool__ = __nonzero__

report_problem = ReportProblem()
148
149
def compile_file(dir):
    """
    Byte-compile the Python files in a directory with ``compileall``.

    Files are recompiled even if they look up to date (``force=1``).
    Failures are returned instead of raised.

    :param dir: directory to compile
    :return: None on success, otherwise the exception that stopped the
        compilation
    """
    try:
        import compileall
        compileall.compile_dir(dir=dir, ddir=dir, force=1,
                               quiet=report_problem)
    except Exception as exc:
        # Exception (not a bare except) so Ctrl-C / SystemExit still
        # propagate instead of being returned as a "compile error".
        return exc
    return None
161
162
def _find_models():
    """
    Load every custom model found in the user's plug-in directory.

    Each ``*.py`` file other than ``__init__.py`` is passed to
    ``load_custom_model``. The loaded model's ``name`` gets the
    ``PLUGIN_NAME_BASE`` prefix. Files that fail to load are reported in the
    plug-in log (with traceback) and as a logging warning, then skipped.

    :return: dict mapping the prefixed model name to the object returned by
        ``load_custom_model``; empty if the plug-in directory is missing
    """
    directory = find_plugins_dir()
    # Go through files in plug-in directory
    if not os.path.isdir(directory):
        logging.warning("SasView couldn't locate Model plugin folder %r.",
                        directory)
        return {}

    plugin_log("looking for models in: %s" % str(directory))
    logging.info("plugin model dir: %s", directory)

    plugins = {}
    # Sorted so that, when two files define the same model name, the one
    # kept does not depend on the arbitrary order of os.listdir().
    for filename in sorted(os.listdir(directory)):
        name, ext = os.path.splitext(filename)
        if ext != '.py' or name == '__init__':
            continue
        path = os.path.abspath(os.path.join(directory, filename))
        try:
            model = load_custom_model(path)
            model.name = PLUGIN_NAME_BASE + model.name
            plugins[model.name] = model
        except Exception:
            msg = traceback.format_exc()
            msg += "\nwhile accessing model in %r" % path
            plugin_log(msg)
            logging.warning("Failed to load plugin %r. See %s for details",
                            path, PLUGIN_LOG)

    return plugins
196
197
class ModelList(object):
    """
    Registry mapping a category name to the list of models in it.
    """
    def __init__(self):
        """
        Start with an empty registry.
        """
        self.mydict = dict()

    def set_list(self, name, mylist):
        """
        Register *mylist* under *name* unless *name* is already registered.

        :param name: the type of the list
        :param mylist: the list to add
        """
        if name in self.mydict:
            return
        self.reset_list(name, mylist)

    def reset_list(self, name, mylist):
        """
        Register *mylist* under *name*, replacing any existing entry.

        :param name: the type of the list
        :param mylist: the list to add
        """
        self.mydict.update({name: mylist})

    def get_list(self):
        """
        Return the registry dictionary itself (not a copy).
        """
        return self.mydict
228
229
class ModelManagerBase:
    """
    Base class for the model manager.

    Gathers the standard sasmodels models and the user's plug-in models,
    keeps a name -> model dictionary, and sorts the models into the
    category lists returned by :meth:`get_model_list`.

    NOTE: the containers below are class attributes, so they are shared by
    every instance until an instance rebinds them.
    """
    ## external dict for models
    model_combobox = ModelList()
    ## Dictionary of form factor models
    form_factor_dict = {}
    ## dictionary of structure factor models
    struct_factor_dict = {}
    ##list of structure factors
    struct_list = []
    ##list of model allowing multiplication by a structure factor
    multiplication_factor = []
    ##list of multifunctional shapes (i.e. that have user defined number of levels
    multi_func_list = []
    ## list of added models -- currently python models found in the plugin dir.
    plugins = []
    ## Event owner (guiframe)
    event_owner = None
    ## modification time of the plug-in directory seen by the last is_changed()
    last_time_dir_modified = 0

    def __init__(self):
        # name -> model, for standard and plug-in models
        self.model_dictionary = {}
        # name -> model, for plug-in models only
        self.stored_plugins = {}
        self._getModelList()

    def findModels(self):
        """
        Load the plug-in models if the plug-in directory has changed.

        :return: the result of ``_find_models()`` when the directory's
            modification time differs from the last check, otherwise an
            empty dict
        """
        temp = {}
        if self.is_changed():
            return _find_models()
        logging.info("plugin model : %s" % str(temp))
        return temp

    def _getModelList(self):
        """
        Build the model dictionary and the category lists.

        Standard sasmodels models are sorted into the structure-factor,
        form-factor and multiplicity lists. Plug-in models found by
        :meth:`findModels` are then added to ``model_dictionary``. The
        structure-factor and form-factor lists are emptied first, so calling
        this again does not duplicate their entries.

        :return: 0
        """
        # regular model names only
        self.model_name_list = []
        # These lists start out as shared class attributes. Appending without
        # clearing duplicated every entry on each call (and for each
        # instance). Clearing in place keeps lists already handed to
        # model_combobox valid.
        del self.struct_list[:]
        del self.multiplication_factor[:]

        #Build list automagically from sasmodels package
        for model in load_standard_models():
            self.model_dictionary[model.name] = model
            if model.is_structure_factor:
                self.struct_list.append(model)
            if model.is_form_factor:
                self.multiplication_factor.append(model)
            if model.is_multiplicity_model:
                # NOTE(review): _get_multifunc_models() below rebinds
                # multi_func_list to plug-in models only, discarding these
                # entries -- confirm this is intended.
                self.multi_func_list.append(model)
            else:
                self.model_name_list.append(model.name)

        #Looking for plugins
        self.stored_plugins = self.findModels()
        # list() because update()/plugins_reset() append to self.plugins,
        # and dict.values() is a non-appendable view on Python 3.
        self.plugins = list(self.stored_plugins.values())
        self.model_dictionary.update(self.stored_plugins)

        self._get_multifunc_models()

        return 0

    def is_changed(self):
        """
        Report whether the plug-in directory changed since the last check.

        Compares the directory's modification time with the value recorded
        by the previous call, and records the new value when it differs.

        :return: True if the modification time differs, else False
        """
        plugin_dir = find_plugins_dir()
        if not os.path.isdir(plugin_dir):
            return False
        mtime = os.path.getmtime(plugin_dir)
        if self.last_time_dir_modified == mtime:
            return False
        self.last_time_dir_modified = mtime
        return True

    def update(self):
        """
        Add plug-in models that appeared since the last check.

        :return: the model_combobox dictionary if :meth:`findModels`
            returned any models, otherwise an empty dict
        """
        new_plugins = self.findModels()
        if not new_plugins:
            return {}
        for name, plug in new_plugins.items():
            if name not in self.stored_plugins:
                self.stored_plugins[name] = plug
                self.plugins.append(plug)
                self.model_dictionary[name] = plug
        self.model_combobox.set_list("Plugin Models", self.plugins)
        return self.model_combobox.get_list()

    def plugins_reset(self):
        """
        Reload every plug-in model from disk.

        Same-named entries are replaced, and the "Plugin Models" category is
        reset to the freshly loaded plug-ins.

        :return: the model_combobox dictionary
        """
        self.plugins = []
        new_plugins = _find_models()
        for name, plug in new_plugins.items():
            # Drop the stale entry first so the new one is re-inserted.
            # pop() tolerates a name missing from model_dictionary.
            if name in self.stored_plugins:
                del self.stored_plugins[name]
                self.model_dictionary.pop(name, None)
            self.stored_plugins[name] = plug
            self.plugins.append(plug)
            self.model_dictionary[name] = plug

        self.model_combobox.reset_list("Plugin Models", self.plugins)
        return self.model_combobox.get_list()

    def _on_model(self, evt):
        """
        React to a model menu event.

        :param evt: menu event; only its ``GetId()`` method is used
        """
        # NOTE(review): form_factor_dict is keyed by int ids and
        # struct_factor_dict by str ids, and the model built here is never
        # stored or returned. This looks like leftover menu code; confirm
        # before relying on it.
        if int(evt.GetId()) in self.form_factor_dict:
            from sasmodels.sasview_model import MultiplicationModel
            self.model_dictionary[MultiplicationModel.__name__] = MultiplicationModel
            model1, model2 = self.form_factor_dict[int(evt.GetId())]
            MultiplicationModel(model1, model2)
        else:
            self.struct_factor_dict[str(evt.GetId())]()

    def _get_multifunc_models(self):
        """
        Set multi_func_list to the plug-in models that are multiplicity models.
        """
        self.multi_func_list = [item for item in self.plugins
                                if item.is_multiplicity_model]

    def get_model_list(self):
        """
        Return the dictionary of model categories for fitpanel use.

        Categories already present in model_combobox are not replaced (see
        ModelList.set_list).
        """
        ## Model_list now only contains attribute lists not category list.
        ## Eventually this should be in one master list -- read in category
        ## list then pull those models that exist and get attributes then add
        ## to list ..and if model does not exist remove from list as now
        ## and update json file.
        ##
        ## -PDB   April 26, 2014
        self.model_combobox.set_list("Structure Factors", self.struct_list)
        self.model_combobox.set_list("Plugin Models", self.plugins)
        self.model_combobox.set_list("P(Q)*S(Q)", self.multiplication_factor)
        self.model_combobox.set_list("multiplication",
                                     self.multiplication_factor)
        self.model_combobox.set_list("Multi-Functions", self.multi_func_list)
        return self.model_combobox.get_list()

    def get_model_name_list(self):
        """
        Return the names of the standard models that are not multiplicity models.
        """
        return self.model_name_list

    def get_model_dictionary(self):
        """
        Return the dictionary linking model names to model objects.
        """
        return self.model_dictionary
411
412
def _non_plugin_models(manager):
    """
    Return the models of *manager* that did not come from the plug-in folder.

    :param manager: a ModelManagerBase instance
    :return: list of the ``manager.model_dictionary`` values whose keys are
        not keys of ``manager.stored_plugins``
    """
    return [model for name, model in manager.model_dictionary.items()
            if name not in manager.stored_plugins]


class ModelManager(object):
    """
    Thin wrapper that forwards every call to a single ModelManagerBase.

    The ModelManagerBase instance is created once, when this class statement
    runs at module import, and is shared by all ModelManager instances.
    ``CategoryInstaller.check_install`` is also run then, with the
    non-plug-in models.
    """
    __modelmanager = ModelManagerBase()
    # Built by a helper rather than an inline comprehension. On Python 3 a
    # comprehension in a class body cannot see class-level names such as
    # __modelmanager outside its first iterable, which raises NameError.
    cat_model_list = _non_plugin_models(__modelmanager)

    CategoryInstaller.check_install(model_list=cat_model_list)

    def findModels(self):
        return self.__modelmanager.findModels()

    def _getModelList(self):
        return self.__modelmanager._getModelList()

    def is_changed(self):
        return self.__modelmanager.is_changed()

    def update(self):
        return self.__modelmanager.update()

    def plugins_reset(self):
        return self.__modelmanager.plugins_reset()

    def populate_menu(self, modelmenu, event_owner):
        # NOTE(review): ModelManagerBase defines no populate_menu, so this
        # raises AttributeError when called.
        return self.__modelmanager.populate_menu(modelmenu, event_owner)

    def _on_model(self, evt):
        return self.__modelmanager._on_model(evt)

    def _get_multifunc_models(self):
        return self.__modelmanager._get_multifunc_models()

    def get_model_list(self):
        return self.__modelmanager.get_model_list()

    def get_model_name_list(self):
        return self.__modelmanager.get_model_name_list()

    def get_model_dictionary(self):
        return self.__modelmanager.get_model_dictionary()
455
Note: See TracBrowser for help on using the repository browser.