source: sasview/src/sas/sasgui/perspectives/fitting/models.py @ feb62c6

Last change on this file since feb62c6 was 11b094f, checked in by butler, 8 years ago

Merge remote-tracking branch 'origin/master' into ticket-756

  • Property mode set to 100644
File size: 14.0 KB
Line 
1"""
2    Utilities to manage models
3"""
4import traceback
5import os
6import sys
7import os.path
8# Time is needed by the log method
9import time
10import datetime
11import logging
12import py_compile
13import shutil
14# Explicitly import from the pluginmodel module so that py2exe
15# places it in the distribution. The Model1DPlugin class is used
16# as the base class of plug-in models.
17from sas.sascalc.fit.pluginmodel import Model1DPlugin
18from sas.sasgui.guiframe.CategoryInstaller import CategoryInstaller
19from sasmodels.sasview_model import load_custom_model, load_standard_models
20
21
# Name of the plug-in model directory: used both under ~/.sasview (the user's
# plug-ins) and next to this module (the example plug-ins that get copied).
PLUGIN_DIR = 'plugin_models'
# Prefix prepended to the name of every model loaded from the plug-in folder.
PLUGIN_NAME_BASE = '[plug-in] '
# File collecting plug-in loading messages and errors.
PLUGIN_LOG = os.path.join(
    os.path.expanduser("~"), '.sasview', PLUGIN_DIR, "plugins.log")
26
def get_model_python_path():
    """
    Return the directory that contains this module.

    :return: ``os.path.dirname(__file__)`` for this module
    """
    this_module_dir = os.path.dirname(__file__)
    return this_module_dir
32
33
def plugin_log(message, log_path=None):
    """
    Append a time-stamped message to the plug-in log file.

    Each entry is written on its own line as
    ``YYYY-MM-DD HH:MM:SS: <message>`` using local time.

    :param message: text to record
    :param log_path: file to append to; defaults to PLUGIN_LOG
        (``~/.sasview/plugin_models/plugins.log``)
    """
    if log_path is None:
        log_path = PLUGIN_LOG
    # Some callers (e.g. plug-in validation) may log before anything has
    # created the plug-in directory, so make sure the parent exists.
    log_dir = os.path.dirname(log_path)
    if log_dir and not os.path.isdir(log_dir):
        try:
            os.makedirs(log_dir)
        except OSError:
            # Tolerate another process creating it in the meantime.
            if not os.path.isdir(log_dir):
                raise
    stamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    # 'with' closes the file even if the write fails.
    with open(log_path, 'a') as out:
        out.write("%s: %s\n" % (stamp, message))
43
44
def _check_plugin(model, name):
    """
    Do some checking before adding a model to the plugin list.

    The model must be a class derived from Model1DPlugin, its class name
    must be ``Model``, it must be instantiable without arguments, and the
    instance must have a ``function`` attribute that can be called without
    arguments.  Each failure is written to the plug-in log.

    :param model: class model to add into the plugin list
    :param name: name of the module plugin (used in log messages only)

    :return model: *model* if valid model or None if not valid
    """
    # Check if the plugin is of type Model1DPlugin; issubclass() raises
    # TypeError when *model* is not a class at all.
    try:
        is_plugin = issubclass(model, Model1DPlugin)
    except TypeError:
        is_plugin = False
    if not is_plugin:
        msg = "Plugin %s must be of type Model1DPlugin \n" % str(name)
        plugin_log(msg)
        return None
    if model.__name__ != "Model":
        msg = "Plugin %s class name must be Model \n" % str(name)
        plugin_log(msg)
        return None
    try:
        new_instance = model()
    except Exception as exc:
        msg = "Plugin %s error in __init__ \n\t: %s %s\n" % (str(name),
                                                             str(type(exc)),
                                                             exc)
        plugin_log(msg)
        return None

    if not hasattr(new_instance, "function"):
        msg = "Plugin  %s needs a method called function \n" % str(name)
        plugin_log(msg)
        return None
    try:
        # Only checks that function() runs; its result is not needed.
        new_instance.function()
    except Exception as exc:
        msg = "Plugin %s: error writing function \n\t :%s %s\n " % \
                (str(name), str(type(exc)), exc)
        plugin_log(msg)
        return None
    return model
86
87
def find_plugins_dir():
    """
    Find path of the user's plugins directory, creating it if needed.

    The directory is ``~/.sasview/<PLUGIN_DIR>``.  As a side effect, every
    example plug-in ``.py`` file shipped with SasView (except
    ``__init__.py``) that the user does not already have is copied into
    it; existing user files are never overwritten.

    :return: path of the user's plugins directory
    :raises RuntimeError: when this module's directory does not exist on
        disk (e.g. a zipped/frozen build) and no shipped PLUGIN_DIR is found
        in it or in up to ten of its parent directories
    """
    plugins_dir = os.path.join(os.path.expanduser("~"), '.sasview', PLUGIN_DIR)

    # If the plugin directory doesn't exist, create it
    if not os.path.isdir(plugins_dir):
        os.makedirs(plugins_dir)

    examples_dir = _find_example_plugins_dir()
    # When running from source the example directory may be absent; then
    # there is simply nothing to copy.
    if os.path.isdir(examples_dir):
        _copy_example_plugins(examples_dir, plugins_dir)

    return plugins_dir


def _find_example_plugins_dir():
    """Return the directory holding the example plug-ins shipped with SasView."""
    module_dir = os.path.dirname(__file__)
    # Running from source: the examples sit next to this module.
    if os.path.isdir(module_dir):
        return os.path.join(module_dir, PLUGIN_DIR)

    # Otherwise look for a data directory next to the exe/zip file: search
    # this module's directory and up to ten of its parents.
    search_dir = module_dir
    for _ in range(11):
        candidate = os.path.join(search_dir, PLUGIN_DIR)
        if os.path.isdir(candidate):
            return candidate
        search_dir = os.path.dirname(search_dir)
    raise RuntimeError("Could not locate the %r directory shipped with "
                       "SasView near %r" % (PLUGIN_DIR, module_dir))


def _copy_example_plugins(source_dir, target_dir):
    """Copy ``*.py`` files (not ``__init__.py``) missing from *target_dir*."""
    for filename in os.listdir(source_dir):
        source_path = os.path.join(source_dir, filename)
        # Same file filter as used when loading plug-ins.
        name, ext = os.path.splitext(filename)
        if ext != '.py' or name == '__init__' or not os.path.isfile(source_path):
            continue
        # Never overwrite a file the user already has.
        if not os.path.isfile(os.path.join(target_dir, filename)):
            shutil.copy(source_path, target_dir)
132
133
class ReportProblem:
    """
    Truth-valued object used as compileall's ``quiet`` flag.

    It is always true, except when it is evaluated while a
    ``py_compile.PyCompileError`` is being handled.  In that case it prints
    the error and re-raises it, so the compile failure propagates as an
    exception.
    """
    def __nonzero__(self):
        exc_type, exc_value = sys.exc_info()[:2]
        if exc_type is not None and issubclass(exc_type, py_compile.PyCompileError):
            sys.stdout.write("Problem with %s\n" % repr(exc_value))
            # A bare raise re-raises the exception currently being handled,
            # with its original traceback (valid on Python 2 and 3).
            raise
        return True

    # Python 3 name of the truth-value hook.
    __bool__ = __nonzero__

report_problem = ReportProblem()
146
147
def compile_file(dir):
    """
    Byte-compile all Python files in a directory tree (best effort).

    Compilation is forced (``force=1``) and ``ddir`` is set to the same
    directory.  report_problem is passed as compileall's ``quiet`` flag;
    ReportProblem re-raises a ``py_compile.PyCompileError`` if it is
    evaluated while one is being handled.

    :param dir: directory to compile (the name shadows the ``dir`` builtin
        but is kept so keyword callers keep working)
    :return: None on success, otherwise the exception that was raised
    """
    try:
        import compileall
        compileall.compile_dir(dir=dir, ddir=dir, force=1,
                               quiet=report_problem)
    except Exception as exc:
        # Hand the error back instead of raising.  Unlike the former bare
        # except, this lets KeyboardInterrupt/SystemExit propagate.
        return exc
    return None
159
160
def _findModels(dir=None):
    """
    Find and load the custom (plug-in) models.

    Every ``*.py`` file other than ``__init__.py`` in the user's plug-in
    directory is loaded with load_custom_model(), and the model's name is
    prefixed with PLUGIN_NAME_BASE.  Files that fail to load are written to
    the plug-in log and skipped.

    :param dir: ignored.  It is kept for backward compatibility only; the
        directory always comes from find_plugins_dir().
    :return: dict mapping prefixed model name to model; empty if the
        plug-in directory cannot be located
    """
    plugins_dir = find_plugins_dir()
    # Go through files in plug-in directory
    if not os.path.isdir(plugins_dir):
        logging.warning("SasView couldn't locate Model plugin folder %r.",
                        plugins_dir)
        return {}

    plugin_log("looking for models in: %s" % str(plugins_dir))
    logging.info("plugin model dir: %s", plugins_dir)

    plugins = {}
    # sorted() makes the load order deterministic, and with it which file
    # wins if two plug-ins declare the same model name.
    for filename in sorted(os.listdir(plugins_dir)):
        name, ext = os.path.splitext(filename)
        if ext != '.py' or name == '__init__':
            continue
        path = os.path.abspath(os.path.join(plugins_dir, filename))
        try:
            model = load_custom_model(path)
            model.name = PLUGIN_NAME_BASE + model.name
        except Exception:
            msg = traceback.format_exc()
            msg += "\nwhile accessing model in %r" % path
            plugin_log(msg)
            logging.warning("Failed to load plugin %r. See %s for details",
                            path, PLUGIN_LOG)
            continue
        plugins[model.name] = model

    return plugins
194
195
class ModelList(object):
    """
    Dictionary of model lists keyed by their type (category name).
    """
    def __init__(self):
        """
        Start with no categories.
        """
        self.mydict = {}

    def set_list(self, name, mylist):
        """
        Store *mylist* under *name* only if *name* is not stored yet.

        :param name: the type of the list
        :param mylist: the list to add
        """
        if name in self.mydict:
            return
        self.reset_list(name, mylist)

    def reset_list(self, name, mylist):
        """
        Store *mylist* under *name*, replacing any previous list.

        :param name: the type of the list
        :param mylist: the list to add
        """
        self.mydict[name] = mylist

    def get_list(self):
        """
        Return the dictionary of all stored lists (the object itself,
        not a copy).
        """
        return self.mydict
226
227
class ModelManagerBase:
    """
    Base class for the model manager.

    Gathers the standard sasmodels models and the user's plug-in models
    into ``model_dictionary`` (name -> model) and into the category lists
    returned by get_model_list().
    """
    ## external dict for models
    # NOTE: class attribute, hence shared by every instance.
    model_combobox = ModelList()
    ## Dictionary of form factor models
    form_factor_dict = {}
    ## dictionary of structure factor models
    struct_factor_dict = {}
    ## list of structure factors
    struct_list = []
    ## list of model allowing multiplication by a structure factor
    multiplication_factor = []
    ## list of multifunctional shapes (i.e. that have user defined number of levels
    multi_func_list = []
    ## list of added models -- currently python models found in the plugin dir.
    plugins = []
    ## Event owner (guiframe)
    event_owner = None
    ## modification time of the plug-in directory at the last is_changed()
    last_time_dir_modified = 0

    def __init__(self):
        self.model_dictionary = {}
        self.stored_plugins = {}
        self._getModelList()

    def findModels(self):
        """
        Return the plug-in models if the plug-in directory has changed.

        :return: dict mapping plug-in name to model when the directory's
            modification time differs from the last check, else an empty dict
        """
        temp = {}
        if self.is_changed():
            # _findModels() locates the plug-in directory itself; the
            # argument is ignored.
            return _findModels(None)
        logging.info("plugin model : %s", str(temp))
        return temp

    def _getModelList(self):
        """
        Build the lists of models available by default for this application.

        Fills ``model_dictionary``, ``model_name_list`` (names of standard
        models that are not multiplicity models), ``struct_list``,
        ``multiplication_factor``, ``multi_func_list``, ``stored_plugins``
        and ``plugins``.  The lists are rebuilt rather than appended to, so
        calling this again does not create duplicates.

        :return: 0
        """
        # Rebind the category lists on the instance.  Appending to the
        # class-level defaults would share them between instances and add
        # duplicates every time this method runs.
        self.struct_list = []
        self.multiplication_factor = []
        self.multi_func_list = []
        # regular model names only
        self.model_name_list = []

        # Build list automagically from sasmodels package
        for model in load_standard_models():
            self.model_dictionary[model.name] = model
            if model.is_structure_factor:
                self.struct_list.append(model)
            if model.is_form_factor:
                self.multiplication_factor.append(model)
            if model.is_multiplicity_model:
                self.multi_func_list.append(model)
            else:
                self.model_name_list.append(model.name)

        # Looking for plugins.  Load them unconditionally: findModels()
        # returns nothing when the directory is unchanged, which would drop
        # every plug-in on a repeated call.  is_changed() is still called to
        # record the current modification time for later update() calls.
        self.is_changed()
        self.stored_plugins = _findModels(None)
        # A real list, so update() and plugins_reset() can append to it.
        self.plugins = list(self.stored_plugins.values())
        self.model_dictionary.update(self.stored_plugins)

        # NOTE(review): this replaces multi_func_list with the plug-in
        # multiplicity models only, discarding the standard ones gathered
        # above -- confirm that is intended.
        self._get_multifunc_models()

        return 0

    def is_changed(self):
        """
        Check whether the plug-in directory changed since the last check.

        The directory's current modification time is recorded on the
        instance whenever it differs from the stored one.

        :return: True if the directory was modified, else False
        """
        plugin_dir = find_plugins_dir()
        if not os.path.isdir(plugin_dir):
            return False
        mtime = os.path.getmtime(plugin_dir)
        if mtime == self.last_time_dir_modified:
            return False
        self.last_time_dir_modified = mtime
        return True

    def update(self):
        """
        Add plug-in models that are new since the last check.

        Plug-ins whose name is already known are left as they are.

        :return: the model category dictionary when findModels() returned
            any plug-in, otherwise an empty dict
        """
        new_plugins = self.findModels()
        if not new_plugins:
            return {}
        for name, plug in new_plugins.items():
            if name not in self.stored_plugins:
                self.stored_plugins[name] = plug
                self.plugins.append(plug)
                self.model_dictionary[name] = plug
        self.model_combobox.set_list("Plugin Models", self.plugins)
        return self.model_combobox.get_list()

    def plugins_reset(self):
        """
        Reload every plug-in model from the plug-in directory.

        Each plug-in found replaces any stored plug-in of the same name.
        The "Plugin Models" category is reset to the reloaded plug-ins.

        :return: the model category dictionary
        """
        self.plugins = []
        new_plugins = _findModels(None)
        for name, plug in new_plugins.items():
            self.stored_plugins[name] = plug
            self.plugins.append(plug)
            self.model_dictionary[name] = plug

        self.model_combobox.reset_list("Plugin Models", self.plugins)
        return self.model_combobox.get_list()

    def _on_model(self, evt):
        """
        React to a model menu event.

        :param evt: wx menu event

        NOTE(review): the model constructed here is discarded; the only
        lasting effect is registering MultiplicationModel in
        model_dictionary.  form_factor_dict is looked up by int id while
        struct_factor_dict is looked up by str id -- confirm both.
        """
        evt_id = evt.GetId()
        if int(evt_id) in self.form_factor_dict:
            from sasmodels.sasview_model import MultiplicationModel
            self.model_dictionary[MultiplicationModel.__name__] = MultiplicationModel
            model1, model2 = self.form_factor_dict[int(evt_id)]
            MultiplicationModel(model1, model2)
        else:
            self.struct_factor_dict[str(evt_id)]()

    def _get_multifunc_models(self):
        """
        Set multi_func_list to the plug-ins that are multiplicity models.
        """
        self.multi_func_list = [plug for plug in self.plugins
                                if plug.is_multiplicity_model]

    def get_model_list(self):
        """
        Return the dictionary of model categories for fitpanel use.

        Categories already present in model_combobox are not replaced.
        """
        ## Model_list now only contains attribute lists not category list.
        ## Eventually this should be in one master list -- read in category
        ## list then pull those models that exist and get attributes then add
        ## to list ..and if model does not exist remove from list as now
        ## and update json file.
        ##
        ## -PDB   April 26, 2014
        self.model_combobox.set_list("Structure Factors", self.struct_list)
        self.model_combobox.set_list("Plugin Models", self.plugins)
        self.model_combobox.set_list("P(Q)*S(Q)", self.multiplication_factor)
        self.model_combobox.set_list("multiplication",
                                     self.multiplication_factor)
        self.model_combobox.set_list("Multi-Functions", self.multi_func_list)
        return self.model_combobox.get_list()

    def get_model_name_list(self):
        """
        return regular model name list
        """
        return self.model_name_list

    def get_model_dictionary(self):
        """
        return dictionary linking model names to objects
        """
        return self.model_dictionary
409
410
class ModelManager(object):
    """
    Public facade over one ModelManagerBase shared by all ModelManager
    instances; every method below simply forwards to it.

    NOTE: the class body runs when this module is imported.  The import
    itself therefore builds the ModelManagerBase, which loads the standard
    sasmodels models and the user's plug-ins, and calls
    CategoryInstaller.check_install().
    """
    # Shared manager; the name is mangled to _ModelManager__modelmanager.
    __modelmanager = ModelManagerBase()
    # All known models except the plug-ins, handed to the category installer.
    # NOTE(review): the element and condition read the class-level
    # __modelmanager, which works with Python 2 comprehension scoping; under
    # Python 3 a class-body comprehension cannot see class-level names there.
    cat_model_list = [__modelmanager.model_dictionary[model_name] for model_name \
                      in __modelmanager.model_dictionary.keys() \
                      if model_name not in __modelmanager.stored_plugins.keys()]

    CategoryInstaller.check_install(model_list=cat_model_list)
    def findModels(self):
        return self.__modelmanager.findModels()

    def _getModelList(self):
        return self.__modelmanager._getModelList()

    def is_changed(self):
        return self.__modelmanager.is_changed()

    def update(self):
        return self.__modelmanager.update()

    def plugins_reset(self):
        return self.__modelmanager.plugins_reset()

    def populate_menu(self, modelmenu, event_owner):
        # NOTE(review): ModelManagerBase in this file defines no
        # populate_menu, so this call looks like it raises AttributeError --
        # confirm whether any caller still uses it.
        return self.__modelmanager.populate_menu(modelmenu, event_owner)

    def _on_model(self, evt):
        return self.__modelmanager._on_model(evt)

    def _get_multifunc_models(self):
        return self.__modelmanager._get_multifunc_models()

    def get_model_list(self):
        return self.__modelmanager.get_model_list()

    def get_model_name_list(self):
        return self.__modelmanager.get_model_name_list()

    def get_model_dictionary(self):
        return self.__modelmanager.get_model_dictionary()
Note: See TracBrowser for help on using the repository browser.