source: sasview/src/sas/sasgui/perspectives/fitting/models.py @ e28f34d

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since e28f34d was 212bfc2, checked in by mathieu, 8 years ago

Pull categories from models. Get rid of default categories. Fixes #535

  • Property mode set to 100644
File size: 13.9 KB
Line 
1"""
2    Utilities to manage models
3"""
4import traceback
5import os
6import sys
7import os.path
8# Time is needed by the log method
9import time
10import datetime
11import logging
12import py_compile
13import shutil
14# Explicitly import from the pluginmodel module so that py2exe
15# places it in the distribution. The Model1DPlugin class is used
16# as the base class of plug-in models.
17from sas.sascalc.fit.pluginmodel import Model1DPlugin
18from sas.sasgui.guiframe.CategoryInstaller import CategoryInstaller
19from sasmodels.sasview_model import load_custom_model, load_standard_models
20
21
# Name of the folder that holds plug-in models.
PLUGIN_DIR = 'plugin_models'
# Log file kept inside the user's plug-in folder (~/.sasview/plugin_models).
PLUGIN_LOG = os.path.join(
    os.path.expanduser("~"), '.sasview', PLUGIN_DIR, "plugins.log"
)
25
def get_model_python_path():
    """
    Return the directory that contains this module.

    :return: ``os.path.dirname(__file__)`` for this file
    """
    this_file = __file__
    return os.path.dirname(this_file)
31
32
def plugin_log(message):
    """
    Append a time-stamped message to the plug-in log file (``PLUGIN_LOG``),
    which is located in the user's home directory.

    :param message: text to record; it is formatted with ``%s``
    """
    now = time.time()
    stamp = datetime.datetime.fromtimestamp(now).strftime('%Y-%m-%d %H:%M:%S')
    # The context manager closes the file even when the write fails, so a
    # failed log attempt no longer leaks the file handle.
    with open(PLUGIN_LOG, 'a') as out:
        out.write("%s: %s\n" % (stamp, message))
42
43
def _check_plugin(model, name):
    """
    Do some checking before adding a plugin model to the list

    The class must derive from Model1DPlugin, be named ``Model``, be
    instantiable without arguments, and its instances must provide a
    ``function`` method that can be called without arguments.  Each failed
    check is recorded with plugin_log().

    :param model: class model to add into the plugin list
    :param name: name of the module plugin

    :return model: model if valid model or None if not valid

    """
    # Check if the plugin is of type Model1DPlugin
    if not issubclass(model, Model1DPlugin):
        plugin_log("Plugin %s must be of type Model1DPlugin \n" % str(name))
        return None
    if model.__name__ != "Model":
        plugin_log("Plugin %s class name must be Model \n" % str(name))
        return None

    try:
        new_instance = model()
    except Exception:
        # sys.exc_info() replaces the deprecated, non-thread-safe
        # sys.exc_type global.
        exc_type, exc_value = sys.exc_info()[:2]
        msg = "Plugin %s error in __init__ \n\t: %s %s\n" % (str(name),
                                                             str(exc_type),
                                                             exc_value)
        plugin_log(msg)
        return None

    if not hasattr(new_instance, "function"):
        plugin_log("Plugin  %s needs a method called function \n" % str(name))
        return None

    try:
        # Only checking that the call succeeds; the result is not needed.
        new_instance.function()
    except Exception:
        exc_type, exc_value = sys.exc_info()[:2]
        msg = "Plugin %s: error writing function \n\t :%s %s\n " % \
                (str(name), str(exc_type), exc_value)
        plugin_log(msg)
        return None
    return model
85
86
def _find_example_plugins_dir():
    """
    Locate the folder holding the example plug-in models.

    :return: the ``PLUGIN_DIR`` folder next to this module when the module's
        own directory exists on disk (the folder itself may be missing);
        otherwise the first existing ``PLUGIN_DIR`` folder found while
        walking up from the module location, or None if there is none
    """
    here = os.path.dirname(__file__)
    # For source
    if os.path.isdir(here):
        return os.path.join(here, PLUGIN_DIR)

    # Check for data path next to exe/zip file: probe the module location
    # and up to 10 of its parent directories.
    base = here
    for _ in range(11):
        candidate = os.path.join(base, PLUGIN_DIR)
        if os.path.isdir(candidate):
            return candidate
        base = os.path.dirname(base)
    return None


def find_plugins_dir():
    """
    Find path of the plugins directory, creating it if needed.
    The plugin directory is located in the user's home directory.

    Example models (``*.py`` files other than ``__init__.py``) found in the
    example plug-in folder are copied into the user's directory when no file
    of the same name is already there; existing files are never replaced.
    A missing example folder is not an error: the copy step is skipped.

    :return: path of the plugins directory in the user's home directory
    """
    user_dir = os.path.join(os.path.expanduser("~"), '.sasview', PLUGIN_DIR)

    # If the plugin directory doesn't exist, create it
    if not os.path.isdir(user_dir):
        os.makedirs(user_dir)

    examples_dir = _find_example_plugins_dir()
    if examples_dir is None:
        logging.warning("example plugin models directory not found")
        return user_dir
    if not os.path.isdir(examples_dir):
        return user_dir

    # Place example user models as needed
    for filename in os.listdir(examples_dir):
        source = os.path.join(examples_dir, filename)
        stem, ext = os.path.splitext(filename)
        if ext != '.py' or stem == '__init__' or not os.path.isfile(source):
            continue
        if not os.path.isfile(os.path.join(user_dir, filename)):
            shutil.copy(source, user_dir)

    return user_dir
131
132
class ReportProblem:
    """
    Truth-value hook used to surface compilation errors.

    An instance is handed to ``compileall.compile_dir`` as its ``quiet``
    argument by compile_file().  Whenever the instance is truth-tested
    while a ``py_compile.PyCompileError`` is being handled, the error is
    printed and re-raised; in every other situation the instance is true.
    """
    def __nonzero__(self):
        """
        :return: 1, unless a PyCompileError is currently being handled
        :raise py_compile.PyCompileError: the error being handled
        """
        exc_type, exc_value = sys.exc_info()[:2]
        if exc_type is not None and \
                issubclass(exc_type, py_compile.PyCompileError):
            print("Problem with %s" % repr(exc_value))
            # A bare raise re-raises the exception being handled together
            # with its original traceback.
            raise
        return 1

report_problem = ReportProblem()
145
146
def compile_file(dir):
    """
    Run ``compileall.compile_dir`` on a directory, forcing recompilation.

    :param dir: directory passed to ``compileall.compile_dir`` as both
        ``dir`` and ``ddir``
    :return: None when no exception was raised, otherwise the exception
        instance that was raised
    """
    failure = None
    try:
        import compileall
        options = {'ddir': dir, 'force': 1, 'quiet': report_problem}
        compileall.compile_dir(dir, **options)
    except:
        # Every failure, whatever its type, is handed back to the caller
        # as a value instead of propagating.
        failure = sys.exc_info()[1]
    return failure
158
159
def _findModels(dir):
    """
    Load the custom models found in the plugin directory.

    :param dir: ignored; the directory searched is always the one
        returned by find_plugins_dir()
    :return: dictionary mapping each model's ``name`` to the loaded model;
        empty when the plugin directory does not exist
    """
    plugin_dir = find_plugins_dir()
    # Nothing to scan when the plug-in directory is missing
    if not os.path.isdir(plugin_dir):
        logging.warning("SasView couldn't locate Model plugin folder %r."
                        % plugin_dir)
        return {}

    plugin_log("looking for models in: %s" % str(plugin_dir))
    logging.info("plugin model dir: %s" % str(plugin_dir))

    found = {}
    for entry in os.listdir(plugin_dir):
        stem, suffix = os.path.splitext(entry)
        # Only python sources other than the package marker are models
        if suffix != '.py' or stem == '__init__':
            continue
        full_path = os.path.abspath(os.path.join(plugin_dir, entry))
        try:
            loaded = load_custom_model(full_path)
            found[loaded.name] = loaded
        except Exception:
            # A broken plug-in is logged and skipped; the scan continues
            details = traceback.format_exc()
            details += "\nwhile accessing model in %r" % full_path
            plugin_log(details)
            logging.warning("Failed to load plugin %r. See %s for details"
                            % (full_path, PLUGIN_LOG))

    return found
192
193
class ModelList(object):
    """
    Registry that maps a list type (name) to a list of models
    """
    def __init__(self):
        """
        Start with an empty registry
        """
        self.mydict = dict()

    def set_list(self, name, mylist):
        """
        Register a list under a name, unless that name is already taken

        :param name: the type of the list
        :param mylist: the list to add

        """
        if name in self.mydict:
            return
        self.reset_list(name, mylist)

    def reset_list(self, name, mylist):
        """
        Register a list under a name, replacing any list stored before

        :param name: the type of the list
        :param mylist: the list to add
        """
        self.mydict[name] = mylist

    def get_list(self):
        """
        :return: the dictionary holding every registered list (the
            internal object itself, not a copy)
        """
        return self.mydict
224
225
class ModelManagerBase:
    """
    Base class for the model manager

    Gathers the models returned by ``load_standard_models`` and the plug-in
    models found in the plugin directory, and sorts them into the lists
    that get_model_list() hands out.
    """
    ## external dict for models
    model_combobox = ModelList()
    ## Dictionary of form factor models
    form_factor_dict = {}
    ## dictionary of structure factor models
    struct_factor_dict = {}
    ##list of structure factors
    struct_list = []
    ##list of model allowing multiplication by a structure factor
    multiplication_factor = []
    ##list of multifunctional shapes (i.e. that have user defined number of levels
    multi_func_list = []
    ## list of added models -- currently python models found in the plugin dir.
    plugins = []
    ## Event owner (guiframe)
    event_owner = None
    last_time_dir_modified = 0

    def __init__(self):
        self.model_dictionary = {}
        self.stored_plugins = {}
        self._getModelList()

    def findModels(self):
        """
        Load the plug-in models if the plugin directory was modified

        :return: dictionary of plug-in models when is_changed() reports a
            modification, otherwise an empty dictionary
        """
        if self.is_changed():
            # Hand _findModels the real plugin directory rather than the
            # ``dir`` builtin.
            return _findModels(find_plugins_dir())
        unchanged = {}
        logging.info("plugin model : %s" % str(unchanged))
        return unchanged

    def _getModelList(self):
        """
        List of models we want to make available by default
        for this application

        Fills model_dictionary, model_name_list, struct_list,
        multiplication_factor, stored_plugins and plugins.

        :return: always 0

        """
        # regular model names only
        self.model_name_list = []
        # Empty the category lists in place so that running this method a
        # second time does not register every standard model twice, while
        # objects already holding a reference to the lists stay in sync.
        del self.struct_list[:]
        del self.multiplication_factor[:]

        #Build list automagically from sasmodels package
        for model in load_standard_models():
            self.model_dictionary[model.name] = model
            if model.is_structure_factor:
                self.struct_list.append(model)
            if model.is_form_factor:
                self.multiplication_factor.append(model)
            if model.is_multiplicity_model:
                # NOTE(review): multi_func_list is rebound by
                # _get_multifunc_models() below, so these entries do not
                # end up in the final list -- confirm this is intended.
                self.multi_func_list.append(model)
            else:
                self.model_name_list.append(model.name)

        #Looking for plugins
        self.stored_plugins = self.findModels()
        # A real list is required: update() appends to it.
        self.plugins = list(self.stored_plugins.values())
        for name, plug in self.stored_plugins.items():
            self.model_dictionary[name] = plug

        self._get_multifunc_models()

        return 0

    def is_changed(self):
        """
        check the last time the plugin dir has changed and return true
        is the directory was modified else return false

        The modification time seen is remembered, so a second call without
        an intervening change of the directory returns false.
        """
        is_modified = False
        plugin_dir = find_plugins_dir()
        if os.path.isdir(plugin_dir):
            modified_at = os.path.getmtime(plugin_dir)
            if self.last_time_dir_modified != modified_at:
                is_modified = True
                self.last_time_dir_modified = modified_at

        return is_modified

    def update(self):
        """
        Register the plug-in models that appeared since the last scan

        :return: the dictionary of model lists if findModels() found any
            plug-in, else an empty dictionary
        """
        new_plugins = self.findModels()
        if not new_plugins:
            return {}

        for name, plug in new_plugins.items():
            if name not in self.stored_plugins:
                self.stored_plugins[name] = plug
                self.plugins.append(plug)
                self.model_dictionary[name] = plug
        self.model_combobox.set_list("Customized Models", self.plugins)
        return self.model_combobox.get_list()

    def plugins_reset(self):
        """
        Reload every plug-in model from the plugin directory

        :return: the dictionary of model lists
        """
        self.plugins = []
        for name, plug in _findModels(find_plugins_dir()).items():
            # Plain assignment replaces whatever was stored under the same
            # name before; nothing needs to be deleted first.
            self.stored_plugins[name] = plug
            self.plugins.append(plug)
            self.model_dictionary[name] = plug

        self.model_combobox.reset_list("Customized Models", self.plugins)
        return self.model_combobox.get_list()

    def _on_model(self, evt):
        """
        React to a model menu event

        :param evt: wx menu event; only its ``GetId()`` method is used

        """
        event_id = evt.GetId()
        if int(event_id) in self.form_factor_dict:
            from sas.sascalc.fit.MultiplicationModel import MultiplicationModel
            self.model_dictionary[MultiplicationModel.__name__] = MultiplicationModel
            model1, model2 = self.form_factor_dict[int(event_id)]
            model = MultiplicationModel(model1, model2)
        else:
            model = self.struct_factor_dict[str(event_id)]()
        # NOTE(review): the model built above is neither stored nor
        # returned -- confirm whether callers expect it.

    def _get_multifunc_models(self):
        """
        Get the multifunctional models

        Rebinds multi_func_list to the plug-ins whose
        ``is_multiplicity_model`` attribute is true.
        """
        self.multi_func_list = [item for item in self.plugins
                                if item.is_multiplicity_model]

    def get_model_list(self):
        """
        return dictionary of models for fitpanel use

        """
        ## Model_list now only contains attribute lists not category list.
        ## Eventually this should be in one master list -- read in category
        ## list then pull those models that exist and get attributes then add
        ## to list ..and if model does not exist remove from list as now
        ## and update json file.
        ##
        ## -PDB   April 26, 2014
        self.model_combobox.set_list("Structure Factors", self.struct_list)
        self.model_combobox.set_list("Customized Models", self.plugins)
        self.model_combobox.set_list("P(Q)*S(Q)", self.multiplication_factor)
        self.model_combobox.set_list("multiplication",
                                     self.multiplication_factor)
        self.model_combobox.set_list("Multi-Functions", self.multi_func_list)
        return self.model_combobox.get_list()

    def get_model_name_list(self):
        """
        return regular model name list
        """
        return self.model_name_list

    def get_model_dictionary(self):
        """
        return dictionary linking model names to objects
        """
        return self.model_dictionary
407
408
class ModelManager(object):
    """
    Front end that forwards every call to one ModelManagerBase instance
    created when this class is defined and shared by all ModelManager
    objects.
    """
    __modelmanager = ModelManagerBase()
    # Models of the model dictionary that are not stored plug-ins
    cat_model_list = [__modelmanager.model_dictionary[model_name]
                      for model_name in __modelmanager.model_dictionary
                      if model_name not in __modelmanager.stored_plugins]

    CategoryInstaller.check_install(model_list=cat_model_list)

    def findModels(self):
        """
        Forward to ModelManagerBase.findModels
        """
        return self.__modelmanager.findModels()

    def _getModelList(self):
        """
        Forward to ModelManagerBase._getModelList
        """
        return self.__modelmanager._getModelList()

    def is_changed(self):
        """
        Forward to ModelManagerBase.is_changed
        """
        return self.__modelmanager.is_changed()

    def update(self):
        """
        Forward to ModelManagerBase.update
        """
        return self.__modelmanager.update()

    def plugins_reset(self):
        """
        Forward to ModelManagerBase.plugins_reset
        """
        return self.__modelmanager.plugins_reset()

    def populate_menu(self, modelmenu, event_owner):
        """
        Forward to ``populate_menu`` of the shared manager

        :param modelmenu: passed through unchanged
        :param event_owner: passed through unchanged
        """
        # NOTE(review): ModelManagerBase, as defined in this file, has no
        # populate_menu method -- confirm whether this is still used.
        return self.__modelmanager.populate_menu(modelmenu, event_owner)

    def _on_model(self, evt):
        """
        Forward to ModelManagerBase._on_model

        :param evt: passed through unchanged
        """
        return self.__modelmanager._on_model(evt)

    def _get_multifunc_models(self):
        """
        Forward to ModelManagerBase._get_multifunc_models
        """
        return self.__modelmanager._get_multifunc_models()

    def get_model_list(self):
        """
        Forward to ModelManagerBase.get_model_list
        """
        return self.__modelmanager.get_model_list()

    def get_model_name_list(self):
        """
        Forward to ModelManagerBase.get_model_name_list
        """
        return self.__modelmanager.get_model_name_list()

    def get_model_dictionary(self):
        """
        Forward to ModelManagerBase.get_model_dictionary
        """
        return self.__modelmanager.get_model_dictionary()
Note: See TracBrowser for help on using the repository browser.