1 | """ |
---|
2 | Utilities to manage models |
---|
3 | """ |
---|
4 | from __future__ import print_function |
---|
5 | |
---|
6 | import os |
---|
7 | import sys |
---|
8 | import time |
---|
9 | import datetime |
---|
10 | import logging |
---|
11 | import traceback |
---|
12 | import py_compile |
---|
13 | import shutil |
---|
14 | |
---|
15 | from sasmodels.sasview_model import load_custom_model, load_standard_models |
---|
16 | |
---|
17 | from sas import get_user_dir |
---|
18 | |
---|
19 | # Explicitly import from the pluginmodel module so that py2exe |
---|
20 | # places it in the distribution. The Model1DPlugin class is used |
---|
21 | # as the base class of plug-in models. |
---|
22 | from .pluginmodel import Model1DPlugin |
---|
23 | |
---|
# Module-level logger for this file.
logger = logging.getLogger(__name__)


# Name of the plugin-model directory; also the parent directory of PLUGIN_LOG.
PLUGIN_DIR = 'plugin_models'
# Append-only log file written by plugin_log().
# NOTE(review): this path is built from get_user_dir(), whereas
# find_plugins_dir() builds ~/.sasview/PLUGIN_DIR directly.  This assumes
# get_user_dir() resolves to ~/.sasview so the log lands in the directory
# find_plugins_dir() creates -- confirm.
PLUGIN_LOG = os.path.join(get_user_dir(), PLUGIN_DIR, "plugins.log")
# Prefix that find_plugin_models() adds to the display name of every plugin model.
PLUGIN_NAME_BASE = '[plug-in] '
---|
30 | |
---|
31 | |
---|
def plugin_log(message):
    """
    Append a timestamped message to the plugin log file (PLUGIN_LOG).

    Each call appends one line of the form ``YYYY-MM-DD HH:MM:SS: <message>``.

    :param message: text to record; it is formatted with ``%s``
    """
    now = time.time()
    stamp = datetime.datetime.fromtimestamp(now).strftime('%Y-%m-%d %H:%M:%S')
    # The context manager releases the file handle even if the write fails;
    # the previous open()/close() pair leaked it on error.
    with open(PLUGIN_LOG, 'a') as out:
        out.write("%s: %s\n" % (stamp, message))
---|
41 | |
---|
42 | |
---|
def _check_plugin(model, name):
    """
    Validate a candidate plugin model class before adding it to the plugin list.

    A valid plugin class must:

    * subclass Model1DPlugin,
    * be named ``Model``,
    * be instantiable with no arguments, and
    * have a ``function`` attribute that can be called with no arguments.

    Every failed check is recorded with plugin_log().

    :param model: class model to add into the plugin list
    :param name: name of the module plugin (used only in log messages)

    :return model: model if valid model or None if not valid
    """
    # Check if the plugin is of type Model1DPlugin
    if not issubclass(model, Model1DPlugin):
        msg = "Plugin %s must be of type Model1DPlugin \n" % str(name)
        plugin_log(msg)
        return None
    if model.__name__ != "Model":
        msg = "Plugin %s class name must be Model \n" % str(name)
        plugin_log(msg)
        return None

    try:
        new_instance = model()
    except Exception as exc:
        # type(exc) replaces sys.exc_type, which is deprecated in Python 2
        # and removed in Python 3 (where this handler itself would have
        # raised AttributeError).
        msg = "Plugin %s error in __init__ \n\t: %s %s\n" % (str(name),
                                                            str(type(exc)),
                                                            exc)
        plugin_log(msg)
        return None

    if not hasattr(new_instance, "function"):
        msg = "Plugin %s needs a method called function \n" % str(name)
        plugin_log(msg)
        return None

    try:
        # Only checks that the call succeeds; the return value is unused.
        new_instance.function()
    except Exception as exc:
        msg = "Plugin %s: error writing function \n\t :%s %s\n " % \
              (str(name), str(type(exc)), exc)
        plugin_log(msg)
        return None
    return model
---|
84 | |
---|
85 | |
---|
def find_plugins_dir():
    """
    Return the path of the plugin model directory, creating it if needed.

    The directory is ``~/.sasview/<PLUGIN_DIR>``.

    :return: path to the plugin directory
    :raises OSError: if the directory does not exist and cannot be created
    """
    path = os.path.join(os.path.expanduser("~"), '.sasview', PLUGIN_DIR)

    # TODO: trigger initialization of plugins dir from installer or startup
    # If the plugin directory doesn't exist, create it.  The isdir/makedirs
    # pair is racy: another process or thread may create the directory in
    # between, so only re-raise if it still does not exist.
    # (os.makedirs(..., exist_ok=True) is not available on Python 2.)
    if not os.path.isdir(path):
        try:
            os.makedirs(path)
        except OSError:
            if not os.path.isdir(path):
                raise
    # TODO: should we be checking for new default models every time?
    # TODO: restore support for default plugins
    #initialize_plugins_dir(path)
    return path
---|
101 | |
---|
102 | |
---|
def initialize_plugins_dir(path):
    """
    Copy the bundled default plugin files into the user plugin directory *path*.

    The default directory is searched for by walking up at most 12 levels
    from this module's directory, looking for a subdirectory named
    PLUGIN_DIR.  If none is found, an error is logged and nothing is copied.
    ``__init__.py`` and ``*.pyc`` files are skipped.  Files that already
    exist in *path* are never overwritten.

    :param path: destination plugin directory
    """
    # TODO: There are no default plugins
    # TODO: Default plugins directory is in sasgui, but models.py is in sascalc
    # TODO: Move default plugins beside sample data files
    # TODO: Should not look for defaults above the root of the sasview install

    def _locate_defaults():
        # Walk up the tree; return the first <ancestor>/PLUGIN_DIR that exists.
        here = os.path.abspath(os.path.dirname(__file__))
        for _level in range(12):
            candidate = os.path.join(here, PLUGIN_DIR)
            if os.path.isdir(candidate):
                return candidate
            here = os.path.split(here)[0]
        return None

    defaults_dir = _locate_defaults()
    if defaults_dir is None:
        logger.error("default plugins directory not found")
        return

    # This may include c files, depending on the example.
    for entry in os.listdir(defaults_dir):
        if entry == "__init__.py" or entry.endswith('.pyc'):
            continue
        src = os.path.join(defaults_dir, entry)
        dst = os.path.join(path, entry)
        if not os.path.isfile(src):
            continue
        if os.path.isfile(dst):
            # Never replace a user's copy, even if the default was updated.
            continue
        shutil.copy(src, dst)
---|
131 | |
---|
132 | |
---|
class ReportProblem(object):
    """
    Truth-value hook used as the ``quiet`` argument of compileall.compile_dir.

    When evaluated for truth while a py_compile.PyCompileError is being
    handled, it prints the error and re-raises it, so compilation problems
    are not silently swallowed.  Otherwise it is simply true.
    """
    def __bool__(self):
        exc_type, exc_value, _ = sys.exc_info()
        if exc_type is not None and issubclass(exc_type,
                                               py_compile.PyCompileError):
            print("Problem with", repr(exc_value))
            # A bare ``raise`` re-raises the exception currently being
            # handled, in both Python 2 and 3.  The previous
            # ``raise type, value, tb`` form is a SyntaxError on Python 3.
            raise
        # Python 3 requires __bool__ to return a bool, not an int.
        return True

    # Python 2 name of the truth-value hook.
    __nonzero__ = __bool__


report_problem = ReportProblem()
---|
145 | |
---|
146 | |
---|
def compile_file(dir):
    """
    Byte-compile the Python files under *dir* using compileall.

    ``report_problem`` is passed as the ``quiet`` argument.  Any exception
    raised while importing compileall or compiling is caught and returned.

    :param dir: directory to compile
    :return: the exception instance on failure, otherwise None
    """
    error = None
    try:
        import compileall
        compileall.compile_dir(dir=dir, ddir=dir, force=0,
                               quiet=report_problem)
    except Exception as exc:
        error = exc
    return error
---|
158 | |
---|
159 | |
---|
def find_plugin_models():
    """
    Load every custom model file found in the plugin directory.

    Each ``*.py`` file other than ``__init__.py`` is loaded with
    load_custom_model().  Its name gets the PLUGIN_NAME_BASE prefix unless
    it already starts with it.  Files that fail to load are logged (the
    full traceback goes to the plugin log) and skipped.

    :return: dict mapping the prefixed model name to the model; an empty
        dict if the plugin directory cannot be located
    """
    plugins_dir = find_plugins_dir()
    if not os.path.isdir(plugins_dir):
        msg = "SasView couldn't locate Model plugin folder %r." % plugins_dir
        logger.warning(msg)
        return {}

    plugin_log("looking for models in: %s" % plugins_dir)
    # Recompiling the folder via compile_file(plugins_dir) is disabled.
    logger.info("plugin model dir: %s", plugins_dir)

    found = {}
    for entry in os.listdir(plugins_dir):
        stem, suffix = os.path.splitext(entry)
        if suffix != '.py' or stem == '__init__':
            continue
        model_path = os.path.abspath(os.path.join(plugins_dir, entry))
        try:
            model = load_custom_model(model_path)
            # TODO: add [plug-in] tag to model name in sasview_model
            if not model.name.startswith(PLUGIN_NAME_BASE):
                model.name = PLUGIN_NAME_BASE + model.name
            found[model.name] = model
        except Exception:
            details = (traceback.format_exc()
                       + "\nwhile accessing model in %r" % model_path)
            plugin_log(details)
            logger.warning("Failed to load plugin %r. See %s for details",
                           model_path, PLUGIN_LOG)

    return found
---|
195 | |
---|
196 | |
---|
class ModelManagerBase(object):
    """
    Base class for the model manager.

    Holds the standard sasmodels models, the plugin models found in the
    plugin directory, and a combined dictionary of both.  The combined
    dictionary is updated in place each time the plugins are reloaded.
    """
    #: mutable dictionary of models, continually updated to reflect the
    #: current set of plugins
    model_dictionary = None  # type: Dict[str, Model]
    #: constant list of standard models
    standard_models = None  # type: Dict[str, Model]
    #: list of plugin models reset each time the plugin directory is queried
    plugin_models = None  # type: Dict[str, Model]
    #: timestamp (os.path.getmtime) on the plugin directory at the last update
    last_time_dir_modified = 0  # type: float

    def __init__(self):
        # The model dictionary is allocated once and then updated in place
        # to reflect the current list of models.  Be sure to clear it rather
        # than reassign it: callers may hold a reference to it (for example
        # through ModelManager.get_model_dictionary).
        self.model_dictionary = {}

        # Build list automagically from sasmodels package
        self.standard_models = {model.name: model
                                for model in load_standard_models()}
        # Look for plugins
        self.plugins_reset()

    def _is_plugin_dir_changed(self):
        """
        Check whether the plugin directory changed since the last call.

        Compares the directory's modification time with the value stored
        at the previous call, and records the new value.

        :return: True if the directory modification time changed, else False
        """
        is_modified = False
        plugin_dir = find_plugins_dir()
        if os.path.isdir(plugin_dir):
            mod_time = os.path.getmtime(plugin_dir)
            if self.last_time_dir_modified != mod_time:
                is_modified = True
                self.last_time_dir_modified = mod_time

        return is_modified

    def composable_models(self):
        """
        Return the names of standard models that can be used in sum/multiply.

        Multiplicity models are excluded.
        """
        # TODO: should scan plugin models in addition to standard models
        # and update model_editor so that it doesn't add plugins to the list
        return [model.name for model in self.standard_models.values()
                if not model.is_multiplicity_model]

    def plugins_update(self):
        """
        Reload the plugin models and return the classified model lists.

        This always performs a full reload (see plugins_reset) and returns
        the get_model_list() dictionary.  It never returns an empty dict.
        """
        # NOTE(review): a shortcut that returned {} when
        # _is_plugin_dir_changed() was False has been disabled.  Presumably
        # this is because editing an existing plugin file need not change
        # the directory mtime -- confirm before re-enabling.
        return self.plugins_reset()

    def plugins_reset(self):
        """
        Re-scan the plugin directory and rebuild the model dictionary.

        model_dictionary is cleared and refilled in place with the standard
        models followed by the plugin models.  A plugin with the same name as
        a standard model replaces it.

        :return: the classified model dictionary from get_model_list()
        """
        self.plugin_models = find_plugin_models()
        self.model_dictionary.clear()
        self.model_dictionary.update(self.standard_models)
        self.model_dictionary.update(self.plugin_models)
        return self.get_model_list()

    def get_model_list(self):
        """
        Return a dictionary of classified models.

        *Structure Factors* are the structure factor models
        *Form Factors* are the form factor models
        *Multi-Functions* are the multiplicity models
        *Plugin Models* are the plugin models

        A model can appear under more than one key, e.g. a plugin can also be
        a structure factor or a multiplicity model.  Models lacking one of
        the classification flags are treated as not having that property.
        """
        ## Model_list now only contains attribute lists not category list.
        ## Eventually this should be in one master list -- read in category
        ## list then pull those models that exist and get attributes then add
        ## to list ..and if model does not exist remove from list as now
        ## and update json file.
        ##
        ## -PDB April 26, 2014

        # Classify models
        structure_factors = []
        form_factors = []
        multiplicity_models = []
        for model in self.model_dictionary.values():
            # Old style models don't have the classification attributes, so
            # every flag is read with getattr (previously only two of the
            # three were, and a missing is_multiplicity_model raised).
            if getattr(model, 'is_structure_factor', False):
                structure_factors.append(model)
            if getattr(model, 'is_form_factor', False):
                form_factors.append(model)
            if getattr(model, 'is_multiplicity_model', False):
                multiplicity_models.append(model)
        plugin_models = list(self.plugin_models.values())

        return {
            "Structure Factors": structure_factors,
            "Form Factors": form_factors,
            "Plugin Models": plugin_models,
            "Multi-Functions": multiplicity_models,
        }
---|
308 | |
---|
309 | |
---|
class ModelManager(object):
    """
    Manage the list of available models.

    Every instance delegates to one shared ModelManagerBase stored on the
    class.  It is created by the first instantiation and reused afterwards,
    so the models are loaded only once.
    """
    base = None  # type: ModelManagerBase

    def __init__(self):
        shared = ModelManager.base
        if shared is not None:
            return
        ModelManager.base = ModelManagerBase()

    def cat_model_list(self):
        """Return the standard models as a list."""
        standard = self.base.standard_models
        return list(standard.values())

    def update(self):
        """Reload plugins via the shared manager's plugins_update()."""
        manager = self.base
        return manager.plugins_update()

    def plugins_reset(self):
        """Re-scan plugins via the shared manager's plugins_reset()."""
        manager = self.base
        return manager.plugins_reset()

    def get_model_list(self):
        """Return the shared manager's classified model dictionary."""
        manager = self.base
        return manager.get_model_list()

    def composable_models(self):
        """Return names of standard models usable in sum/multiply."""
        manager = self.base
        return manager.composable_models()

    def get_model_dictionary(self):
        """Return the shared (in-place updated) name -> model dictionary."""
        manager = self.base
        return manager.model_dictionary
---|