source: sasview/src/sas/sasgui/perspectives/fitting/models.py @ 8f46df7

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 8f46df7 was ab3ed7e, checked in by gonzalezm, 9 years ago

Discover and use python and python+C models added to the user plugin_models folder

  • Property mode set to 100644
File size: 15.3 KB
Line 
1"""
2    Utilities to manage models
3"""
4import wx
5import imp
6import os
7import sys
8import math
9import os.path
10# Time is needed by the log method
11import time
12import logging
13import py_compile
14import shutil
15from sas.sasgui.guiframe.events import StatusEvent
16# Explicitly import from the pluginmodel module so that py2exe
17# places it in the distribution. The Model1DPlugin class is used
18# as the base class of plug-in models.
19from sas.models.pluginmodel import Model1DPlugin
20from sas.models.BaseComponent import BaseComponent
21from sas.sasgui.guiframe.CategoryInstaller import CategoryInstaller
22from sasmodels import sasview_model,core
23
24
25PLUGIN_DIR = 'plugin_models'
26
27def get_model_python_path():
28    """
29        Returns the python path for a model
30    """
31    return os.path.dirname(__file__)
32
33
34def log(message):
35    """
36        Log a message in a file located in the user's home directory
37    """
38    dir = os.path.join(os.path.expanduser("~"), '.sasview', PLUGIN_DIR)
39    out = open(os.path.join(dir, "plugins.log"), 'a')
40    out.write("%10g%s\n" % (time.clock(), message))
41    out.close()
42
43
44def _check_plugin(model, name):
45    """
46    Do some checking before model adding plugins in the list
47
48    :param model: class model to add into the plugin list
49    :param name:name of the module plugin
50
51    :return model: model if valid model or None if not valid
52
53    """
54    #Check if the plugin is of type Model1DPlugin
55    if not issubclass(model, Model1DPlugin):
56        msg = "Plugin %s must be of type Model1DPlugin \n" % str(name)
57        log(msg)
58        return None
59    if model.__name__ != "Model":
60        msg = "Plugin %s class name must be Model \n" % str(name)
61        log(msg)
62        return None
63    try:
64        new_instance = model()
65    except:
66        msg = "Plugin %s error in __init__ \n\t: %s %s\n" % (str(name),
67                                                             str(sys.exc_type),
68                                                             sys.exc_info()[1])
69        log(msg)
70        return None
71
72    if hasattr(new_instance, "function"):
73        try:
74            value = new_instance.function()
75        except:
76            msg = "Plugin %s: error writing function \n\t :%s %s\n " % \
77                    (str(name), str(sys.exc_type), sys.exc_info()[1])
78            log(msg)
79            return None
80    else:
81        msg = "Plugin  %s needs a method called function \n" % str(name)
82        log(msg)
83        return None
84    return model
85
86
87def find_plugins_dir():
88    """
89        Find path of the plugins directory.
90        The plugin directory is located in the user's home directory.
91    """
92    dir = os.path.join(os.path.expanduser("~"), '.sasview', PLUGIN_DIR)
93
94    # If the plugin directory doesn't exist, create it
95    if not os.path.isdir(dir):
96        os.makedirs(dir)
97
98    # Find paths needed
99    try:
100        # For source
101        if os.path.isdir(os.path.dirname(__file__)):
102            p_dir = os.path.join(os.path.dirname(__file__), PLUGIN_DIR)
103        else:
104            raise
105    except:
106        # Check for data path next to exe/zip file.
107        #Look for maximum n_dir up of the current dir to find plugins dir
108        n_dir = 12
109        p_dir = None
110        f_dir = os.path.join(os.path.dirname(__file__))
111        for i in range(n_dir):
112            if i > 1:
113                f_dir, _ = os.path.split(f_dir)
114            plugin_path = os.path.join(f_dir, PLUGIN_DIR)
115            if os.path.isdir(plugin_path):
116                p_dir = plugin_path
117                break
118        if not p_dir:
119            raise
120    # Place example user models as needed
121    if os.path.isdir(p_dir):
122        for file in os.listdir(p_dir):
123            file_path = os.path.join(p_dir, file)
124            if os.path.isfile(file_path):
125                if file.split(".")[-1] == 'py' and\
126                    file.split(".")[0] != '__init__':
127                    if not os.path.isfile(os.path.join(dir, file)):
128                        shutil.copy(file_path, dir)
129
130    return dir
131
132
133class ReportProblem:
134    """
135        Class to check for problems with specific values
136    """
137    def __nonzero__(self):
138        type, value, traceback = sys.exc_info()
139        if type is not None and issubclass(type, py_compile.PyCompileError):
140            print "Problem with", repr(value)
141            raise type, value, traceback
142        return 1
143
144report_problem = ReportProblem()
145
146
147def compile_file(dir):
148    """
149    Compile a py file
150    """
151    try:
152        import compileall
153        compileall.compile_dir(dir=dir, ddir=dir, force=1,
154                               quiet=report_problem)
155    except:
156        return sys.exc_info()[1]
157    return None
158
159
160def _findModels(dir):
161    """
162    """
163    # List of plugin objects
164    plugins = {}
165    dir = find_plugins_dir()
166    # Go through files in plug-in directory
167    #always recompile the folder plugin
168    if not os.path.isdir(dir):
169        msg = "SasView couldn't locate Model plugin folder."
170        msg += """ "%s" does not exist""" % dir
171        logging.warning(msg)
172        return plugins
173    else:
174        log("looking for models in: %s" % str(dir))
175        compile_file(dir)
176        logging.info("plugin model dir: %s" % str(dir))
177    try:
178        list = os.listdir(dir)
179        for item in list:
180            toks = os.path.splitext(os.path.basename(item))
181            if toks[1] == '.py' and not toks[0] == '__init__':
182                name = toks[0]
183                path = [os.path.abspath(dir)]
184                file = None
185                try:
186                    (file, path, info) = imp.find_module(name, path)
187                    module = imp.load_module(name, file, item, info)
188                    if hasattr(module, "Model"):
189                        try:
190                            if _check_plugin(module.Model, name) != None:
191                                plugins[name] = module.Model
192                        except:
193                            msg = "Error accessing Model"
194                            msg += "in %s\n  %s %s\n" % (name,
195                                                         str(sys.exc_type),
196                                                         sys.exc_info()[1])
197                            log(msg)
198                    else:
199                        filename = os.path.join(dir, item)
200                        plugins[name] = sasview_model.make_class_from_file(filename) 
201
202                except:
203                    msg = "Error accessing Model"
204                    msg += " in %s\n  %s %s \n" % (name,
205                                                   str(sys.exc_type),
206                                                   sys.exc_info()[1])
207                    log(msg)
208                finally:
209
210                    if not file == None:
211                        file.close()
212    except:
213        # Don't deal with bad plug-in imports. Just skip.
214        msg = "Could not import model plugin: %s" % sys.exc_info()[1]
215        log(msg)
216
217    return plugins
218
219
220class ModelList(object):
221    """
222    Contains dictionary of model and their type
223    """
224    def __init__(self):
225        """
226        """
227        self.mydict = {}
228
229    def set_list(self, name, mylist):
230        """
231        :param name: the type of the list
232        :param mylist: the list to add
233
234        """
235        if name not in self.mydict.keys():
236            self.reset_list(name, mylist)
237
238    def reset_list(self, name, mylist):
239        """
240        :param name: the type of the list
241        :param mylist: the list to add
242        """
243        self.mydict[name] = mylist
244
245    def get_list(self):
246        """
247        return all the list stored in a dictionary object
248        """
249        return self.mydict
250
251
252class ModelManagerBase:
253    """
254        Base class for the model manager
255    """
256    ## external dict for models
257    model_combobox = ModelList()
258    ## Dictionary of form factor models
259    form_factor_dict = {}
260    ## dictionary of structure factor models
261    struct_factor_dict = {}
262    ##list of structure factors
263    struct_list = []
264    ##list of model allowing multiplication by a structure factor
265    multiplication_factor = []
266    ##list of multifunctional shapes (i.e. that have user defined number of levels
267    multi_func_list = []
268    ## list of added models -- currently python models found in the plugin dir.
269    plugins = []
270    ## Event owner (guiframe)
271    event_owner = None
272    last_time_dir_modified = 0
273
274    def __init__(self):
275        """
276        """
277        self.model_dictionary = {}
278        self.stored_plugins = {}
279        self._getModelList()
280
281    def findModels(self):
282        """
283        find  plugin model in directory of plugin .recompile all file
284        in the directory if file were modified
285        """
286        temp = {}
287        if self.is_changed():
288            return  _findModels(dir)
289        logging.info("plugin model : %s" % str(temp))
290        return temp
291
292    def _getModelList(self):
293        """
294        List of models we want to make available by default
295        for this application
296
297        :return: the next free event ID following the new menu events
298
299        """
300
301        # regular model names only
302        base_message = "Unable to load model {0}"
303        self.model_name_list = []
304
305        #Build list automagically from sasmodels package
306        for model in sasview_model.standard_models():
307            self.model_dictionary[model._model_info['name']] = model
308            if model._model_info['structure_factor'] == True:
309                self.struct_list.append(model)
310            if model._model_info['variant_info'] is not None:
311                self.multi_func_list.append(model)
312            else:
313                self.model_name_list.append(model._model_info['name'])
314            if model._model_info['ER'] is not None:
315                self.multiplication_factor.append(model)
316
317        #Looking for plugins
318        self.stored_plugins = self.findModels()
319        self.plugins = self.stored_plugins.values()
320        for name, plug in self.stored_plugins.iteritems():
321            self.model_dictionary[name] = plug
322       
323        self._get_multifunc_models()
324
325        return 0
326
327    def is_changed(self):
328        """
329        check the last time the plugin dir has changed and return true
330         is the directory was modified else return false
331        """
332        is_modified = False
333        plugin_dir = find_plugins_dir()
334        if os.path.isdir(plugin_dir):
335            temp = os.path.getmtime(plugin_dir)
336            if  self.last_time_dir_modified != temp:
337                is_modified = True
338                self.last_time_dir_modified = temp
339
340        return is_modified
341
342    def update(self):
343        """
344        return a dictionary of model if
345        new models were added else return empty dictionary
346        """
347        new_plugins = self.findModels()
348        if len(new_plugins) > 0:
349            for name, plug in  new_plugins.iteritems():
350                if name not in self.stored_plugins.keys():
351                    self.stored_plugins[name] = plug
352                    self.plugins.append(plug)
353                    self.model_dictionary[name] = plug
354            self.model_combobox.set_list("Customized Models", self.plugins)
355            return self.model_combobox.get_list()
356        else:
357            return {}
358
359    def plugins_reset(self):
360        """
361        return a dictionary of model
362        """
363        self.plugins = []
364        new_plugins = _findModels(dir)
365        for name, plug in  new_plugins.iteritems():
366            for stored_name, stored_plug in self.stored_plugins.iteritems():
367                if name == stored_name:
368                    del self.stored_plugins[name]
369                    del self.model_dictionary[name]
370                    break
371            self.stored_plugins[name] = plug
372            self.plugins.append(plug)
373            self.model_dictionary[name] = plug
374
375        self.model_combobox.reset_list("Customized Models", self.plugins)
376        return self.model_combobox.get_list()
377
378    def _on_model(self, evt):
379        """
380        React to a model menu event
381
382        :param event: wx menu event
383
384        """
385        if int(evt.GetId()) in self.form_factor_dict.keys():
386            from sas.models.MultiplicationModel import MultiplicationModel
387            self.model_dictionary[MultiplicationModel.__name__] = MultiplicationModel
388            model1, model2 = self.form_factor_dict[int(evt.GetId())]
389            model = MultiplicationModel(model1, model2)
390        else:
391            model = self.struct_factor_dict[str(evt.GetId())]()
392
393
394    def _get_multifunc_models(self):
395        """
396        Get the multifunctional models
397        """
398        for item in self.plugins:
399            try:
400                # check the multiplicity if any
401                if item.multiplicity_info[0] > 1:
402                    self.multi_func_list.append(item)
403            except:
404                # pass to other items
405                pass
406
407    def get_model_list(self):
408        """
409        return dictionary of models for fitpanel use
410
411        """
412        ## Model_list now only contains attribute lists not category list.
413        ## Eventually this should be in one master list -- read in category
414        ## list then pull those models that exist and get attributes then add
415        ## to list ..and if model does not exist remove from list as now
416        ## and update json file.
417        ##
418        ## -PDB   April 26, 2014
419
420#        self.model_combobox.set_list("Shapes", self.shape_list)
421#        self.model_combobox.set_list("Shape-Independent",
422#                                     self.shape_indep_list)
423        self.model_combobox.set_list("Structure Factors", self.struct_list)
424        self.model_combobox.set_list("Customized Models", self.plugins)
425        self.model_combobox.set_list("P(Q)*S(Q)", self.multiplication_factor)
426        self.model_combobox.set_list("multiplication",
427                                     self.multiplication_factor)
428        self.model_combobox.set_list("Multi-Functions", self.multi_func_list)
429        return self.model_combobox.get_list()
430
431    def get_model_name_list(self):
432        """
433        return regular model name list
434        """
435        return self.model_name_list
436
437    def get_model_dictionary(self):
438        """
439        return dictionary linking model names to objects
440        """
441        return self.model_dictionary
442
443
444class ModelManager(object):
445    """
446    implement model
447    """
448    __modelmanager = ModelManagerBase()
449    cat_model_list = [model_name for model_name \
450                      in __modelmanager.model_dictionary.keys() \
451                      if model_name not in __modelmanager.stored_plugins.keys()]
452
453    CategoryInstaller.check_install(model_list=cat_model_list)
454    def findModels(self):
455        return self.__modelmanager.findModels()
456
457    def _getModelList(self):
458        return self.__modelmanager._getModelList()
459
460    def is_changed(self):
461        return self.__modelmanager.is_changed()
462
463    def update(self):
464        return self.__modelmanager.update()
465
466    def plugins_reset(self):
467        return self.__modelmanager.plugins_reset()
468
469    def populate_menu(self, modelmenu, event_owner):
470        return self.__modelmanager.populate_menu(modelmenu, event_owner)
471
472    def _on_model(self, evt):
473        return self.__modelmanager._on_model(evt)
474
475    def _get_multifunc_models(self):
476        return self.__modelmanager._get_multifunc_models()
477
478    def get_model_list(self):
479        return self.__modelmanager.get_model_list()
480
481    def get_model_name_list(self):
482        return self.__modelmanager.get_model_name_list()
483
484    def get_model_dictionary(self):
485        return self.__modelmanager.get_model_dictionary()
Note: See TracBrowser for help on using the repository browser.