source: sasview/src/sas/sasgui/perspectives/fitting/models.py @ dcdca68

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since dcdca68 was dcdca68, checked in by Paul Kienzle <pkienzle@…>, 9 years ago

adjust to new sasmodels interface

  • Property mode set to 100644
File size: 15.1 KB
Line 
1"""
2    Utilities to manage models
3"""
4import imp
5import os
6import sys
7import os.path
8# Time is needed by the log method
9import time
10import logging
11import py_compile
12import shutil
13# Explicitly import from the pluginmodel module so that py2exe
14# places it in the distribution. The Model1DPlugin class is used
15# as the base class of plug-in models.
16from sas.sascalc.fit.pluginmodel import Model1DPlugin
17from sas.sasgui.guiframe.CategoryInstaller import CategoryInstaller
18from sasmodels.sasview_model import load_custom_model, load_standard_models
19
20
21PLUGIN_DIR = 'plugin_models'
22
23def get_model_python_path():
24    """
25    Returns the python path for a model
26    """
27    return os.path.dirname(__file__)
28
29
30def log(message):
31    """
32    Log a message in a file located in the user's home directory
33    """
34    dir = os.path.join(os.path.expanduser("~"), '.sasview', PLUGIN_DIR)
35    out = open(os.path.join(dir, "plugins.log"), 'a')
36    out.write("%10g%s\n" % (time.clock(), message))
37    out.close()
38
39
40def _check_plugin(model, name):
41    """
42    Do some checking before model adding plugins in the list
43
44    :param model: class model to add into the plugin list
45    :param name:name of the module plugin
46
47    :return model: model if valid model or None if not valid
48
49    """
50    #Check if the plugin is of type Model1DPlugin
51    if not issubclass(model, Model1DPlugin):
52        msg = "Plugin %s must be of type Model1DPlugin \n" % str(name)
53        log(msg)
54        return None
55    if model.__name__ != "Model":
56        msg = "Plugin %s class name must be Model \n" % str(name)
57        log(msg)
58        return None
59    try:
60        new_instance = model()
61    except:
62        msg = "Plugin %s error in __init__ \n\t: %s %s\n" % (str(name),
63                                                             str(sys.exc_type),
64                                                             sys.exc_info()[1])
65        log(msg)
66        return None
67
68    if hasattr(new_instance, "function"):
69        try:
70            value = new_instance.function()
71        except:
72            msg = "Plugin %s: error writing function \n\t :%s %s\n " % \
73                    (str(name), str(sys.exc_type), sys.exc_info()[1])
74            log(msg)
75            return None
76    else:
77        msg = "Plugin  %s needs a method called function \n" % str(name)
78        log(msg)
79        return None
80    return model
81
82
83def find_plugins_dir():
84    """
85    Find path of the plugins directory.
86    The plugin directory is located in the user's home directory.
87    """
88    dir = os.path.join(os.path.expanduser("~"), '.sasview', PLUGIN_DIR)
89
90    # If the plugin directory doesn't exist, create it
91    if not os.path.isdir(dir):
92        os.makedirs(dir)
93
94    # Find paths needed
95    try:
96        # For source
97        if os.path.isdir(os.path.dirname(__file__)):
98            p_dir = os.path.join(os.path.dirname(__file__), PLUGIN_DIR)
99        else:
100            raise
101    except:
102        # Check for data path next to exe/zip file.
103        #Look for maximum n_dir up of the current dir to find plugins dir
104        n_dir = 12
105        p_dir = None
106        f_dir = os.path.join(os.path.dirname(__file__))
107        for i in range(n_dir):
108            if i > 1:
109                f_dir, _ = os.path.split(f_dir)
110            plugin_path = os.path.join(f_dir, PLUGIN_DIR)
111            if os.path.isdir(plugin_path):
112                p_dir = plugin_path
113                break
114        if not p_dir:
115            raise
116    # Place example user models as needed
117    if os.path.isdir(p_dir):
118        for file in os.listdir(p_dir):
119            file_path = os.path.join(p_dir, file)
120            if os.path.isfile(file_path):
121                if file.split(".")[-1] == 'py' and\
122                    file.split(".")[0] != '__init__':
123                    if not os.path.isfile(os.path.join(dir, file)):
124                        shutil.copy(file_path, dir)
125
126    return dir
127
128
129class ReportProblem:
130    """
131    Class to check for problems with specific values
132    """
133    def __nonzero__(self):
134        type, value, traceback = sys.exc_info()
135        if type is not None and issubclass(type, py_compile.PyCompileError):
136            print "Problem with", repr(value)
137            raise type, value, traceback
138        return 1
139
140report_problem = ReportProblem()
141
142
143def compile_file(dir):
144    """
145    Compile a py file
146    """
147    try:
148        import compileall
149        compileall.compile_dir(dir=dir, ddir=dir, force=1,
150                               quiet=report_problem)
151    except:
152        return sys.exc_info()[1]
153    return None
154
155
156def _findModels(dir):
157    # List of plugin objects
158    plugins = {}
159    dir = find_plugins_dir()
160    # Go through files in plug-in directory
161    #always recompile the folder plugin
162    if not os.path.isdir(dir):
163        msg = "SasView couldn't locate Model plugin folder."
164        msg += """ "%s" does not exist""" % dir
165        logging.warning(msg)
166        return plugins
167    else:
168        log("looking for models in: %s" % str(dir))
169        compile_file(dir)
170        logging.info("plugin model dir: %s" % str(dir))
171    try:
172        list = os.listdir(dir)
173        for item in list:
174            toks = os.path.splitext(os.path.basename(item))
175            if toks[1] == '.py' and not toks[0] == '__init__':
176                name = toks[0]
177                path = [os.path.abspath(dir)]
178                file = None
179                try:
180                    (file, path, info) = imp.find_module(name, path)
181                    module = imp.load_module(name, file, item, info)
182                    if hasattr(module, "Model"):
183                        try:
184                            if _check_plugin(module.Model, name) != None:
185                                plugins[name] = module.Model
186                        except:
187                            msg = "Error accessing Model"
188                            msg += "in %s\n  %s %s\n" % (name,
189                                                         str(sys.exc_type),
190                                                         sys.exc_info()[1])
191                            log(msg)
192                    else:
193                        filename = os.path.join(dir, item)
194                        plugins[name] = load_custom_model(filename)
195
196                except:
197                    msg = "Error accessing Model"
198                    msg += " in %s\n  %s %s \n" % (name,
199                                                   str(sys.exc_type),
200                                                   sys.exc_info()[1])
201                    log(msg)
202                finally:
203
204                    if not file == None:
205                        file.close()
206    except:
207        # Don't deal with bad plug-in imports. Just skip.
208        msg = "Could not import model plugin: %s" % sys.exc_info()[1]
209        log(msg)
210
211    return plugins
212
213
214class ModelList(object):
215    """
216    Contains dictionary of model and their type
217    """
218    def __init__(self):
219        """
220        """
221        self.mydict = {}
222
223    def set_list(self, name, mylist):
224        """
225        :param name: the type of the list
226        :param mylist: the list to add
227
228        """
229        if name not in self.mydict.keys():
230            self.reset_list(name, mylist)
231
232    def reset_list(self, name, mylist):
233        """
234        :param name: the type of the list
235        :param mylist: the list to add
236        """
237        self.mydict[name] = mylist
238
239    def get_list(self):
240        """
241        return all the list stored in a dictionary object
242        """
243        return self.mydict
244
245
246class ModelManagerBase:
247    """
248    Base class for the model manager
249    """
250    ## external dict for models
251    model_combobox = ModelList()
252    ## Dictionary of form factor models
253    form_factor_dict = {}
254    ## dictionary of structure factor models
255    struct_factor_dict = {}
256    ##list of structure factors
257    struct_list = []
258    ##list of model allowing multiplication by a structure factor
259    multiplication_factor = []
260    ##list of multifunctional shapes (i.e. that have user defined number of levels
261    multi_func_list = []
262    ## list of added models -- currently python models found in the plugin dir.
263    plugins = []
264    ## Event owner (guiframe)
265    event_owner = None
266    last_time_dir_modified = 0
267
268    def __init__(self):
269        self.model_dictionary = {}
270        self.stored_plugins = {}
271        self._getModelList()
272
273    def findModels(self):
274        """
275        find  plugin model in directory of plugin .recompile all file
276        in the directory if file were modified
277        """
278        temp = {}
279        if self.is_changed():
280            return  _findModels(dir)
281        logging.info("plugin model : %s" % str(temp))
282        return temp
283
284    def _getModelList(self):
285        """
286        List of models we want to make available by default
287        for this application
288
289        :return: the next free event ID following the new menu events
290
291        """
292
293        # regular model names only
294        self.model_name_list = []
295
296        #Build list automagically from sasmodels package
297        for model in load_standard_models():
298            self.model_dictionary[model._model_info['name']] = model
299            if model._model_info['structure_factor'] == True:
300                self.struct_list.append(model)
301            if model._model_info['variant_info'] is not None:
302                self.multi_func_list.append(model)
303            else:
304                self.model_name_list.append(model._model_info['name'])
305            if model._model_info['ER'] is not None:
306                self.multiplication_factor.append(model)
307
308        #Looking for plugins
309        self.stored_plugins = self.findModels()
310        self.plugins = self.stored_plugins.values()
311        for name, plug in self.stored_plugins.iteritems():
312            self.model_dictionary[name] = plug
313       
314        self._get_multifunc_models()
315
316        return 0
317
318    def is_changed(self):
319        """
320        check the last time the plugin dir has changed and return true
321        is the directory was modified else return false
322        """
323        is_modified = False
324        plugin_dir = find_plugins_dir()
325        if os.path.isdir(plugin_dir):
326            temp = os.path.getmtime(plugin_dir)
327            if  self.last_time_dir_modified != temp:
328                is_modified = True
329                self.last_time_dir_modified = temp
330
331        return is_modified
332
333    def update(self):
334        """
335        return a dictionary of model if
336        new models were added else return empty dictionary
337        """
338        new_plugins = self.findModels()
339        if len(new_plugins) > 0:
340            for name, plug in  new_plugins.iteritems():
341                if name not in self.stored_plugins.keys():
342                    self.stored_plugins[name] = plug
343                    self.plugins.append(plug)
344                    self.model_dictionary[name] = plug
345            self.model_combobox.set_list("Customized Models", self.plugins)
346            return self.model_combobox.get_list()
347        else:
348            return {}
349
350    def plugins_reset(self):
351        """
352        return a dictionary of model
353        """
354        self.plugins = []
355        new_plugins = _findModels(dir)
356        for name, plug in  new_plugins.iteritems():
357            for stored_name, stored_plug in self.stored_plugins.iteritems():
358                if name == stored_name:
359                    del self.stored_plugins[name]
360                    del self.model_dictionary[name]
361                    break
362            self.stored_plugins[name] = plug
363            self.plugins.append(plug)
364            self.model_dictionary[name] = plug
365
366        self.model_combobox.reset_list("Customized Models", self.plugins)
367        return self.model_combobox.get_list()
368
369    def _on_model(self, evt):
370        """
371        React to a model menu event
372
373        :param event: wx menu event
374
375        """
376        if int(evt.GetId()) in self.form_factor_dict.keys():
377            from sas.sascalc.fit.MultiplicationModel import MultiplicationModel
378            self.model_dictionary[MultiplicationModel.__name__] = MultiplicationModel
379            model1, model2 = self.form_factor_dict[int(evt.GetId())]
380            model = MultiplicationModel(model1, model2)
381        else:
382            model = self.struct_factor_dict[str(evt.GetId())]()
383
384
385    def _get_multifunc_models(self):
386        """
387        Get the multifunctional models
388        """
389        for item in self.plugins:
390            try:
391                # check the multiplicity if any
392                if item.multiplicity_info[0] > 1:
393                    self.multi_func_list.append(item)
394            except:
395                # pass to other items
396                pass
397
398    def get_model_list(self):
399        """
400        return dictionary of models for fitpanel use
401
402        """
403        ## Model_list now only contains attribute lists not category list.
404        ## Eventually this should be in one master list -- read in category
405        ## list then pull those models that exist and get attributes then add
406        ## to list ..and if model does not exist remove from list as now
407        ## and update json file.
408        ##
409        ## -PDB   April 26, 2014
410
411#        self.model_combobox.set_list("Shapes", self.shape_list)
412#        self.model_combobox.set_list("Shape-Independent",
413#                                     self.shape_indep_list)
414        self.model_combobox.set_list("Structure Factors", self.struct_list)
415        self.model_combobox.set_list("Customized Models", self.plugins)
416        self.model_combobox.set_list("P(Q)*S(Q)", self.multiplication_factor)
417        self.model_combobox.set_list("multiplication",
418                                     self.multiplication_factor)
419        self.model_combobox.set_list("Multi-Functions", self.multi_func_list)
420        return self.model_combobox.get_list()
421
422    def get_model_name_list(self):
423        """
424        return regular model name list
425        """
426        return self.model_name_list
427
428    def get_model_dictionary(self):
429        """
430        return dictionary linking model names to objects
431        """
432        return self.model_dictionary
433
434
435class ModelManager(object):
436    """
437    implement model
438    """
439    __modelmanager = ModelManagerBase()
440    cat_model_list = [model_name for model_name \
441                      in __modelmanager.model_dictionary.keys() \
442                      if model_name not in __modelmanager.stored_plugins.keys()]
443
444    CategoryInstaller.check_install(model_list=cat_model_list)
445    def findModels(self):
446        return self.__modelmanager.findModels()
447
448    def _getModelList(self):
449        return self.__modelmanager._getModelList()
450
451    def is_changed(self):
452        return self.__modelmanager.is_changed()
453
454    def update(self):
455        return self.__modelmanager.update()
456
457    def plugins_reset(self):
458        return self.__modelmanager.plugins_reset()
459
460    def populate_menu(self, modelmenu, event_owner):
461        return self.__modelmanager.populate_menu(modelmenu, event_owner)
462
463    def _on_model(self, evt):
464        return self.__modelmanager._on_model(evt)
465
466    def _get_multifunc_models(self):
467        return self.__modelmanager._get_multifunc_models()
468
469    def get_model_list(self):
470        return self.__modelmanager.get_model_list()
471
472    def get_model_name_list(self):
473        return self.__modelmanager.get_model_name_list()
474
475    def get_model_dictionary(self):
476        return self.__modelmanager.get_model_dictionary()
Note: See TracBrowser for help on using the repository browser.