source: sasview/src/sas/sasgui/perspectives/fitting/models.py @ 724af06

ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, magnetic_scatt, release-4.2.2, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Last change on this file since 724af06 was 724af06, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

Merge branch 'master' into ticket-887-reorg

  • Property mode set to 100644
File size: 16.3 KB
Line 
1"""
2    Utilities to manage models
3"""
4from __future__ import print_function
5
6import traceback
7import os
8import sys
9import os.path
10# Time is needed by the log method
11import time
12import datetime
13import logging
14import py_compile
15import shutil
16from copy import copy
17
18from sasmodels.sasview_model import load_custom_model, load_standard_models
19
20# Explicitly import from the pluginmodel module so that py2exe
21# places it in the distribution. The Model1DPlugin class is used
22# as the base class of plug-in models.
23from sas.sasgui import get_user_dir
24from sas.sascalc.fit.pluginmodel import Model1DPlugin
25from sas.sasgui.guiframe.CategoryInstaller import CategoryInstaller
26from sas.sasgui.perspectives.fitting.fitpage import CUSTOM_MODEL
27
28logger = logging.getLogger(__name__)
29
30
31PLUGIN_DIR = 'plugin_models'
32PLUGIN_LOG = os.path.join(get_user_dir(), PLUGIN_DIR, "plugins.log")
33PLUGIN_NAME_BASE = '[plug-in] '
34
35def get_model_python_path():
36    """
37    Returns the python path for a model
38    """
39    return os.path.dirname(__file__)
40
41
42def plugin_log(message):
43    """
44    Log a message in a file located in the user's home directory
45    """
46    out = open(PLUGIN_LOG, 'a')
47    now = time.time()
48    stamp = datetime.datetime.fromtimestamp(now).strftime('%Y-%m-%d %H:%M:%S')
49    out.write("%s: %s\n" % (stamp, message))
50    out.close()
51
52
53def _check_plugin(model, name):
54    """
55    Do some checking before model adding plugins in the list
56
57    :param model: class model to add into the plugin list
58    :param name:name of the module plugin
59
60    :return model: model if valid model or None if not valid
61
62    """
63    #Check if the plugin is of type Model1DPlugin
64    if not issubclass(model, Model1DPlugin):
65        msg = "Plugin %s must be of type Model1DPlugin \n" % str(name)
66        plugin_log(msg)
67        return None
68    if model.__name__ != "Model":
69        msg = "Plugin %s class name must be Model \n" % str(name)
70        plugin_log(msg)
71        return None
72    try:
73        new_instance = model()
74    except:
75        msg = "Plugin %s error in __init__ \n\t: %s %s\n" % (str(name),
76                                                             str(sys.exc_type),
77                                                             sys.exc_info()[1])
78        plugin_log(msg)
79        return None
80
81    if hasattr(new_instance, "function"):
82        try:
83            value = new_instance.function()
84        except:
85            msg = "Plugin %s: error writing function \n\t :%s %s\n " % \
86                    (str(name), str(sys.exc_type), sys.exc_info()[1])
87            plugin_log(msg)
88            return None
89    else:
90        msg = "Plugin  %s needs a method called function \n" % str(name)
91        plugin_log(msg)
92        return None
93    return model
94
95
96def find_plugins_dir():
97    """
98    Find path of the plugins directory.
99    The plugin directory is located in the user's home directory.
100    """
101    dir = os.path.join(os.path.expanduser("~"), '.sasview', PLUGIN_DIR)
102
103    # If the plugin directory doesn't exist, create it
104    if not os.path.isdir(dir):
105        os.makedirs(dir)
106
107    # Find paths needed
108    try:
109        # For source
110        if os.path.isdir(os.path.dirname(__file__)):
111            p_dir = os.path.join(os.path.dirname(__file__), PLUGIN_DIR)
112        else:
113            raise
114    except:
115        # Check for data path next to exe/zip file.
116        #Look for maximum n_dir up of the current dir to find plugins dir
117        n_dir = 12
118        p_dir = None
119        f_dir = os.path.join(os.path.dirname(__file__))
120        for i in range(n_dir):
121            if i > 1:
122                f_dir, _ = os.path.split(f_dir)
123            plugin_path = os.path.join(f_dir, PLUGIN_DIR)
124            if os.path.isdir(plugin_path):
125                p_dir = plugin_path
126                break
127        if not p_dir:
128            raise
129    # Place example user models as needed
130    if os.path.isdir(p_dir):
131        for file in os.listdir(p_dir):
132            file_path = os.path.join(p_dir, file)
133            if os.path.isfile(file_path):
134                if file.split(".")[-1] == 'py' and\
135                    file.split(".")[0] != '__init__':
136                    if not os.path.isfile(os.path.join(dir, file)):
137                        shutil.copy(file_path, dir)
138
139    return dir
140
141
142class ReportProblem:
143    """
144    Class to check for problems with specific values
145    """
146    def __nonzero__(self):
147        type, value, tb = sys.exc_info()
148        if type is not None and issubclass(type, py_compile.PyCompileError):
149            print("Problem with", repr(value))
150            raise type, value, tb
151        return 1
152
153report_problem = ReportProblem()
154
155
156def compile_file(dir):
157    """
158    Compile a py file
159    """
160    try:
161        import compileall
162        compileall.compile_dir(dir=dir, ddir=dir, force=0,
163                               quiet=report_problem)
164    except:
165        return sys.exc_info()[1]
166    return None
167
168
169def _find_models():
170    """
171    Find custom models
172    """
173    # List of plugin objects
174    directory = find_plugins_dir()
175    # Go through files in plug-in directory
176    if not os.path.isdir(directory):
177        msg = "SasView couldn't locate Model plugin folder %r." % directory
178        logger.warning(msg)
179        return {}
180
181    plugin_log("looking for models in: %s" % str(directory))
182    # compile_file(directory)  #always recompile the folder plugin
183    logger.info("plugin model dir: %s" % str(directory))
184
185    plugins = {}
186    for filename in os.listdir(directory):
187        name, ext = os.path.splitext(filename)
188        if ext == '.py' and not name == '__init__':
189            path = os.path.abspath(os.path.join(directory, filename))
190            try:
191                model = load_custom_model(path)
192                if not model.name.count(PLUGIN_NAME_BASE):
193                    model.name = PLUGIN_NAME_BASE + model.name
194                plugins[model.name] = model
195            except Exception:
196                msg = traceback.format_exc()
197                msg += "\nwhile accessing model in %r" % path
198                plugin_log(msg)
199                logger.warning("Failed to load plugin %r. See %s for details"
200                               % (path, PLUGIN_LOG))
201
202    return plugins
203
204
205class ModelList(object):
206    """
207    Contains dictionary of model and their type
208    """
209    def __init__(self):
210        """
211        """
212        self.mydict = {}
213
214    def set_list(self, name, mylist):
215        """
216        :param name: the type of the list
217        :param mylist: the list to add
218
219        """
220        if name not in self.mydict.keys():
221            self.reset_list(name, mylist)
222
223    def reset_list(self, name, mylist):
224        """
225        :param name: the type of the list
226        :param mylist: the list to add
227        """
228        self.mydict[name] = mylist
229
230    def get_list(self):
231        """
232        return all the list stored in a dictionary object
233        """
234        return self.mydict
235
236
237class ModelManagerBase:
238    """
239    Base class for the model manager
240    """
241    ## external dict for models
242    model_combobox = ModelList()
243    ## Dictionary of form factor models
244    form_factor_dict = {}
245    ## dictionary of structure factor models
246    struct_factor_dict = {}
247    ##list of structure factors
248    struct_list = []
249    ##list of model allowing multiplication by a structure factor
250    multiplication_factor = []
251    ##list of multifunctional shapes (i.e. that have user defined number of levels
252    multi_func_list = []
253    ## list of added models -- currently python models found in the plugin dir.
254    plugins = []
255    ## Event owner (guiframe)
256    event_owner = None
257    last_time_dir_modified = 0
258
259    def __init__(self):
260        self.model_dictionary = {}
261        self.stored_plugins = {}
262        self._getModelList()
263
264    def findModels(self):
265        """
266        find  plugin model in directory of plugin .recompile all file
267        in the directory if file were modified
268        """
269        temp = {}
270        if self.is_changed():
271            temp =  _find_models()
272            self.last_time_dir_modified = time.time()
273            return temp
274        logger.info("plugin model : %s" % str(temp))
275        return temp
276
277    def _getModelList(self):
278        """
279        List of models we want to make available by default
280        for this application
281
282        :return: the next free event ID following the new menu events
283
284        """
285
286        # Regular model names only
287        self.model_name_list = []
288
289        # Build list automagically from sasmodels package
290        for model in load_standard_models():
291            self.model_dictionary[model.name] = model
292            if model.is_structure_factor:
293                self.struct_list.append(model)
294            if model.is_form_factor:
295                self.multiplication_factor.append(model)
296            if model.is_multiplicity_model:
297                self.multi_func_list.append(model)
298            else:
299                self.model_name_list.append(model.name)
300
301        # Looking for plugins
302        self.stored_plugins = self.findModels()
303        self.plugins = self.stored_plugins.values()
304        for name, plug in self.stored_plugins.iteritems():
305            self.model_dictionary[name] = plug
306            # TODO: Remove 'hasattr' statements when old style plugin models
307            # are no longer supported. All sasmodels models will have
308            # the required attributes.
309            if hasattr(plug, 'is_structure_factor') and plug.is_structure_factor:
310                self.struct_list.append(plug)
311                self.plugins.remove(plug)
312            elif hasattr(plug, 'is_form_factor') and plug.is_form_factor:
313                self.multiplication_factor.append(plug)
314            if hasattr(plug, 'is_multiplicity_model') and plug.is_multiplicity_model:
315                self.multi_func_list.append(plug)
316
317        self._get_multifunc_models()
318
319        return 0
320
321    def is_changed(self):
322        """
323        check the last time the plugin dir has changed and return true
324        is the directory was modified else return false
325        """
326        is_modified = False
327        plugin_dir = find_plugins_dir()
328        if os.path.isdir(plugin_dir):
329            temp = os.path.getmtime(plugin_dir)
330            if  self.last_time_dir_modified < temp:
331                is_modified = True
332                self.last_time_dir_modified = temp
333
334        return is_modified
335
336    def update(self):
337        """
338        return a dictionary of model if
339        new models were added else return empty dictionary
340        """
341        self.plugins = []
342        new_plugins = self.findModels()
343        if new_plugins:
344            for name, plug in  new_plugins.items():
345                self.stored_plugins[name] = plug
346                self.plugins.append(plug)
347                self.model_dictionary[name] = plug
348            self.model_combobox.set_list(CUSTOM_MODEL, self.plugins)
349            return self.model_combobox.get_list()
350        else:
351            return {}
352
353    def plugins_reset(self):
354        """
355        return a dictionary of model
356        """
357        self.plugins = []
358        self.stored_plugins = _find_models()
359        structure_names = [model.name for model in self.struct_list]
360        form_names = [model.name for model in self.multiplication_factor]
361
362        # Remove all plugin structure factors and form factors
363        for name in copy(structure_names):
364            if '[plug-in]' in name:
365                i = structure_names.index(name)
366                del self.struct_list[i]
367                structure_names.remove(name)
368        for name in copy(form_names):
369            if '[plug-in]' in name:
370                i = form_names.index(name)
371                del self.multiplication_factor[i]
372                form_names.remove(name)
373
374        # Add new plugin structure factors and form factors
375        for name, plug in self.stored_plugins.iteritems():
376            if plug.is_structure_factor:
377                if name in structure_names:
378                    # Delete the old model from self.struct list
379                    i = structure_names.index(name)
380                    del self.struct_list[i]
381                # Add the new model to self.struct_list
382                self.struct_list.append(plug)
383            elif plug.is_form_factor:
384                if name in form_names:
385                    # Delete the old model from self.multiplication_factor
386                    i = form_names.index(name)
387                    del self.multiplication_factor[i]
388                # Add the new model to self.multiplication_factor
389                self.multiplication_factor.append(plug)
390
391            # Add references to the updated model
392            self.stored_plugins[name] = plug
393            if not plug.is_structure_factor:
394                # Don't show S(Q) models in the 'Plugin Models' dropdown
395                self.plugins.append(plug)
396            self.model_dictionary[name] = plug
397
398        self.model_combobox.reset_list("Plugin Models", self.plugins)
399        self.model_combobox.reset_list("Structure Factors", self.struct_list)
400        self.model_combobox.reset_list("P(Q)*S(Q)", self.multiplication_factor)
401
402        return self.model_combobox.get_list()
403
404    def _on_model(self, evt):
405        """
406        React to a model menu event
407
408        :param event: wx menu event
409
410        """
411        if int(evt.GetId()) in self.form_factor_dict.keys():
412            from sasmodels.sasview_model import MultiplicationModel
413            self.model_dictionary[MultiplicationModel.__name__] = MultiplicationModel
414            model1, model2 = self.form_factor_dict[int(evt.GetId())]
415            model = MultiplicationModel(model1, model2)
416        else:
417            model = self.struct_factor_dict[str(evt.GetId())]()
418
419
420    def _get_multifunc_models(self):
421        """
422        Get the multifunctional models
423        """
424        items = [item for item in self.plugins if item.is_multiplicity_model]
425        self.multi_func_list = items
426
427    def get_model_list(self):
428        """
429        return dictionary of models for fitpanel use
430
431        """
432        ## Model_list now only contains attribute lists not category list.
433        ## Eventually this should be in one master list -- read in category
434        ## list then pull those models that exist and get attributes then add
435        ## to list ..and if model does not exist remove from list as now
436        ## and update json file.
437        ##
438        ## -PDB   April 26, 2014
439
440#        self.model_combobox.set_list("Shapes", self.shape_list)
441#        self.model_combobox.set_list("Shape-Independent",
442#                                     self.shape_indep_list)
443        self.model_combobox.set_list("Structure Factors", self.struct_list)
444        self.model_combobox.set_list("Plugin Models", self.plugins)
445        self.model_combobox.set_list("P(Q)*S(Q)", self.multiplication_factor)
446        self.model_combobox.set_list("multiplication",
447                                     self.multiplication_factor)
448        self.model_combobox.set_list("Multi-Functions", self.multi_func_list)
449        return self.model_combobox.get_list()
450
451    def get_model_name_list(self):
452        """
453        return regular model name list
454        """
455        return self.model_name_list
456
457    def get_model_dictionary(self):
458        """
459        return dictionary linking model names to objects
460        """
461        return self.model_dictionary
462
463
464class ModelManager(object):
465    """
466    implement model
467    """
468    __modelmanager = ModelManagerBase()
469    cat_model_list = [__modelmanager.model_dictionary[model_name] for model_name \
470                      in __modelmanager.model_dictionary.keys() \
471                      if model_name not in __modelmanager.stored_plugins.keys()]
472
473    CategoryInstaller.check_install(model_list=cat_model_list)
474    def findModels(self):
475        return self.__modelmanager.findModels()
476
477    def _getModelList(self):
478        return self.__modelmanager._getModelList()
479
480    def is_changed(self):
481        return self.__modelmanager.is_changed()
482
483    def update(self):
484        return self.__modelmanager.update()
485
486    def plugins_reset(self):
487        return self.__modelmanager.plugins_reset()
488
489    def populate_menu(self, modelmenu, event_owner):
490        return self.__modelmanager.populate_menu(modelmenu, event_owner)
491
492    def _on_model(self, evt):
493        return self.__modelmanager._on_model(evt)
494
495    def _get_multifunc_models(self):
496        return self.__modelmanager._get_multifunc_models()
497
498    def get_model_list(self):
499        return self.__modelmanager.get_model_list()
500
501    def get_model_name_list(self):
502        return self.__modelmanager.get_model_name_list()
503
504    def get_model_dictionary(self):
505        return self.__modelmanager.get_model_dictionary()
Note: See TracBrowser for help on using the repository browser.