source: sasview/src/sas/sascalc/fit/expression.py @ ab0b93f

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalcmagnetic_scattrelease-4.2.2ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since ab0b93f was a1b8fee, checked in by andyfaff, 8 years ago

MAINT: from future import print_function

  • Property mode set to 100644
File size: 13.7 KB
Line 
# This program is public domain
"""
Parameter expression evaluator.

For systems in which constraints are expressed as string expressions rather
than python code, :func:`compile_constraints` can construct an expression
evaluator that substitutes the computed values of the expressions into the
parameters.

The compiler requires a symbol table, an expression set and a context.
The symbol table maps strings containing fully qualified names such as
'M1.c[3].full_width' to parameter objects with a 'value' property that
can be queried and set.  The expression set maps symbol names from the
symbol table to string expressions.  The context provides additional symbols
for the expressions in addition to the usual mathematical functions and
constants.

The expressions are compiled and interpreted by python, with only minimal
effort to make sure that they don't contain bad code.  The resulting
constraints function returns 0 so it can be used directly in a fit problem
definition.

Extracting the symbol table from the model depends on the structure of the
model.  If fitness.parameters() is set correctly, then this should simply
be a matter of walking the parameter data, remembering the path to each
parameter in the symbol table.  For compactness, dictionary elements should
be referenced by .name rather than ["name"].  Model name can be used as the
top level.

Getting the parameter expressions applied correctly is challenging.
The following monkey patch works by overriding model_update in FitProblem
so that, after setp(p) is called, the constraints expression can be
applied before telling the underlying fitness function that the model
is out of date::

        # Override model update so that parameter constraints are applied
        problem._model_update = problem.model_update
        def model_update():
            constraints()
            problem._model_update()
        problem.model_update = model_update

Ideally, this interface will change
"""
from __future__ import print_function

import math
import re
49
# Simple pattern which matches symbols.  Note that it will also match
# invalid substrings such as a3...9, but given syntactically correct
# input it will only match symbols.  Group 1 is the whole dotted name.
_symbol_pattern = re.compile(r'([a-zA-Z_][a-zA-Z_0-9.]*)')
54
def _symbols(expr, symtab):
    """
    Return the set of symbol-table entries referenced by an expression.

    *expr* is the expression string and *symtab* maps symbol names to
    parameter objects.  Every name matched by the symbol pattern that is
    also a key of *symtab* contributes ``symtab[name]`` to the result, so
    the returned set holds the table's values (which must be hashable),
    not the names.  Each entry appears once even if the symbol occurs
    several times, and the set has no particular order.

    This is the first step in computing a dependency graph.
    """
    names = (m.group(0) for m in _symbol_pattern.finditer(expr))
    return {symtab[name] for name in names if name in symtab}
66
def _substitute(expr, mapping):
    """
    Replace all occurrences of symbol s with mapping[s] for s in mapping.

    *expr* is the expression string and *mapping* maps symbol names to
    their replacement text.  Only whole symbols are replaced: with a
    mapping for 'a.b', the longer symbol 'a.b.x' is left alone.  Symbols
    that are not keys of *mapping* are copied through unchanged.

    Returns the rewritten expression string.
    """
    def _rename(match):
        name = match.group(1)
        # Returning the matched text leaves unmapped symbols untouched
        return mapping[name] if name in mapping else match.group(0)

    # A callable replacement is inserted literally, so backslashes in the
    # replacement text are not interpreted as group references.
    return _symbol_pattern.sub(_rename, expr)
86
def _find_dependencies(symtab, exprs):
    """
    Return a list of pair-wise dependencies from the parameter expressions.

    *symtab* maps symbol names to parameter objects and *exprs* maps the
    name of each computed symbol to its expression string.

    Each pair is (target, source) where target is a key of *exprs* and
    source is the *symtab* entry for a symbol used in its expression.
    For example, with exprs = {'p3': 'p1+p2'} the result contains
    ('p3', symtab['p1']) and ('p3', symtab['p2']).  A base expression
    without dependencies, such as exprs = {'p4': '2*pi'}, yields
    ('p4', None).
    """
    return [(target, source)
            for target, expr in exprs.items()
            for source in _symbols_or_none(expr, symtab)]
99
# Hack to deal with expressions without dependencies --- return a fake
# dependency of None.
# The better solution is fix order_dependencies so that it takes a
# dictionary of {symbol: dependency_list}, for which no dependencies
# is simply []; fix in parameter_mapping as well
def _symbols_or_none(expr, symtab):
    """
    Return the symbols used in *expr*, or [None] if it uses none.

    The result is the set produced by _symbols when that set is not
    empty; otherwise it is the one-element list [None], which stands in
    as a placeholder dependency.
    """
    return _symbols(expr, symtab) or [None]
108
def _parameter_mapping(pairs):
    """
    Find the parameter substitution we need so that expressions can
    be evaluated without having to traverse a chain of
    model.layer.parameter.value

    *pairs* is a sequence of dependency pairs.  Every distinct element of
    the pairs other than None is numbered in sorted order, so the
    elements must be hashable and mutually comparable.

    Returns (definition, substitution) where definition maps 'Pn' to the
    n-th element and substitution maps that element to 'Pn.value'.  An
    empty *pairs* gives two empty dictionaries.
    """
    # None marks an expression with no dependencies; it is not a parameter
    pars = sorted({p for pair in pairs for p in pair if p is not None})
    definition = {'P%d' % i: p for i, p in enumerate(pars)}
    substitution = {p: 'P%d.value' % i for i, p in enumerate(pars)}
    return definition, substitution
123
def no_constraints():
    """
    This parameter set has no constraints between the parameters.

    Returns 0, the same value returned by the functions that
    compile_constraints builds, so either kind of constraints function
    can be used directly in a fit problem definition.
    """
    return 0
129
def compile_constraints(symtab, exprs, context=None):
    """
    Build and return a function to evaluate all parameter expressions in
    the proper order.

    Input:

        *symtab* is the symbol table for the model: { 'name': parameter }

        *exprs* is the set of computed symbols: { 'name': 'expression' }

        *context* is any additional context needed to evaluate the
        expression: { 'name': object }, or None for no extra symbols

    Return:

        updater function which sets parameter.value for each expression
        and returns 0.  If *exprs* is empty, no_constraints is returned.

    Raises:

       SyntaxError - improper expression syntax

       ValueError - expressions have circular dependencies

    A model, parameter or function which is missing from *symtab*, the
    math functions and *context* is not detected here; it shows up as a
    NameError or AttributeError when the returned function is called.

    This function is not terribly sophisticated, and it would be easy to
    trick.  However it handles the common cases cleanly and generates
    reasonable messages for the common errors.

    This code has not been fully audited for security.  While we have
    removed the builtins and the ability to import modules, there may
    be other vectors for users to perform more than simple function
    evaluations.  Unauthenticated users should not be running this code.

    Parameter names are assumed to contain only _.a-zA-Z0-9#[]

    Both names are provided for inverse functions, e.g., acos and arccos.

    Should try running the function to identify syntax errors before
    running it in a fit.

    Use help(fn) to see the code generated for the returned function fn.
    dis.dis(fn) will show the corresponding python vm instructions.
    """
    if not exprs:
        return no_constraints

    # Build (target, source) pairs using symbol *names* on both sides.
    # Pairing target names with source parameter objects would leave
    # every source looking independent, so chained expressions would
    # never be ordered and cycles would never be detected.
    deps = []
    for target, expr in exprs.items():
        sources = set(m.group(1) for m in _symbol_pattern.finditer(expr)
                      if m.group(1) in symtab)
        # None is the placeholder dependency for a constant expression
        deps.extend((target, source) for source in (sources or [None]))

    # order_dependencies lists each target ahead of the targets it depends
    # on; evaluation has to run the other way, sources first.
    order = order_dependencies(deps)
    order.reverse()

    # Rather than using the full path to the parameters in the parameter
    # expressions, instead use Pn, and substitute Pn.value for each occurrence
    # of the parameter in the expression.
    names = sorted(symtab)
    parameters = dict(('P%d' % i, symtab[name])
                      for i, name in enumerate(names))
    mapping = dict((name, 'P%d.value' % i) for i, name in enumerate(names))

    # Initialize the namespace with the available functions.  The module
    # bookkeeping attributes of math (__name__, __doc__, ...) are left out.
    namespace = dict((name, value) for name, value in vars(math).items()
                     if not name.startswith('_'))
    namespace.update(arcsin=math.asin, arccos=math.acos,
                     arctan=math.atan, arctan2=math.atan2)
    if context:
        namespace.update(context)
    namespace.update(parameters)
    namespace['id'] = id
    # SECURITY: expressions are executed as python code.  The builtins must
    # be emptied *before* exec: a function picks up its builtins from its
    # globals when it is created (CPython 3.10+), so removing the entry
    # afterwards would leave __import__ and friends reachable.
    namespace['__builtins__'] = {}
    scope = {}

    # Define the constraints function
    assignments = ["=".join((name, exprs[name])) for name in order]
    code = [_substitute(line, mapping) for line in assignments]
    functiondef = """
def eval_expressions():
    '''
    %s
    '''
    %s
    return 0
""" % ("\n    ".join(assignments), "\n    ".join(code))

    #print("Function: "+functiondef)
    exec(functiondef, namespace, scope)
    return scope['eval_expressions']
223
def order_dependencies(pairs):
    """
    Order the left-hand elements of *pairs* by dependency.

    *pairs* is any iterable of two-element items (a, b); the elements
    must be hashable.  The returned list contains each distinct left-hand
    element exactly once, arranged so that a comes before b for every
    pair (a, b) in which b is also a left-hand element.  Elements which
    only ever occur on the right are not returned.

    Raises ValueError if the pairs contain a cycle.
    """
    # Materialize the input so that arrays and iterators behave like a
    # list of tuples in the emptiness tests and repeated scans below.
    pairs = [(a, b) for a, b in pairs]
    order = []
    while pairs:
        left = set(a for a, _ in pairs)
        right = set(b for _, b in pairs)

        # Items which only occur on the right depend on nothing remaining
        independent = right - left
        if not independent:
            cycleset = ", ".join(str(s) for s in left)
            raise ValueError("Cyclic dependencies amongst %s" % cycleset)

        # The possibly resolvable items are those that depend on the
        # independents; drop the pairs which are now satisfied.
        dependent = set(a for a, b in pairs if b in independent)
        pairs = [(a, b) for a, b in pairs if b not in independent]

        # An item is resolved once none of its pairs remain
        still_waiting = set(a for a, _ in pairs)
        order += dependent - still_waiting
    order.reverse()
    return order
256
# ========= Test code ========
def _check(msg, pairs):
    """
    Verify the result of order_dependencies for the given pairs.

    The ordered list must contain each left-hand item of *pairs* exactly
    once, and for every pair (lo, hi) whose items are both in the list,
    lo must come before hi.  *pairs* may be any iterable of two-element
    rows.  *msg* labels the test in the exception raised on failure.
    """
    # Rows become tuples so that array input is handled like a plain list
    pairs = [tuple(pair) for pair in pairs]
    items = set(lo for lo, _ in pairs)
    n = order_dependencies(pairs)
    if set(n) != items or len(n) != len(items):
        raise Exception("%s expect %s to contain %s for %s"
                        % (msg, sorted(n), sorted(items), pairs))
    # Look positions up once instead of scanning the list for every pair
    position = dict((item, i) for i, item in enumerate(n))
    for lo, hi in pairs:
        if lo in position and hi in position and position[lo] >= position[hi]:
            raise Exception("%s expect %s before %s in %s for %s"
                            % (msg, lo, hi, n, pairs))
273
def test_deps():
    """
    Exercise order_dependencies on empty, small, cyclic, large and deep
    sets of pairs, given both as lists and as numpy arrays.
    """
    import numpy as np

    # Null case
    _check("test empty", [])

    # Some dependencies
    _check("test1", [(2, 7), (1, 5), (1, 4), (2, 1), (3, 1), (5, 6)])
    _check("test1 renumbered", [(6, 1), (7, 3), (7, 4), (6, 7), (5, 7), (3, 2)])
    _check("test1 numpy", np.array([(2, 7), (1, 5), (1, 4), (2, 1), (3, 1), (5, 6)]))

    # No dependencies
    _check("test2", [(4, 1), (3, 2), (8, 4)])

    # Cycle test
    pairs = [(1, 4), (4, 3), (4, 5), (5, 1)]
    try:
        order_dependencies(pairs)
    except ValueError:
        pass
    else:
        raise Exception("test3 expect ValueError exception for %s" % (pairs,))

    # large test for gross speed check
    A = np.random.randint(4000, size=(1000, 2))
    A[:, 1] += 4000  # Avoid cycles
    _check("test-large", A)

    # depth tests
    k = 200
    A = np.array([range(0, k), range(1, k+1)]).T
    _check("depth-1", A)

    A = np.array([range(1, k+1), range(0, k)]).T
    _check("depth-2", A)
306
def test_expr():
    """
    Exercise symbol lookup, symbol renaming, dependency discovery and
    compile_constraints, including extra context and invalid expressions.
    """
    import dis
    import inspect
    import math

    symtab = {'a.b.x':1, 'a.c':2, 'a.b':3, 'b.x':4}
    expr = 'a.b.x + sin(4*pi*a.c) + a.b.x/a.b'

    # Check symbol lookup
    assert _symbols(expr, symtab) == set([1,2,3])

    # Check symbol rename
    assert _substitute(expr,{'a.b.x':'Q'}) == 'Q + sin(4*pi*a.c) + Q/a.b'
    assert _substitute(expr,{'a.b':'Q'}) == 'a.b.x + sin(4*pi*a.c) + a.b.x/Q'

    # Check dependency builder
    # Fake parameter class
    class Parameter:
        def __init__(self, name, value=0, expression=''):
            self.path = name
            self.value = value
            self.expression = expression
        def iscomputed(self): return (self.expression != '')
        def __repr__(self): return self.path

    def world(*pars):
        # Build the (symtab, exprs) arguments from a list of parameters
        symtab = dict((p.path,p) for p in pars)
        exprs = dict((p.path,p.expression) for p in pars if p.iscomputed())
        return symtab, exprs

    p1 = Parameter('G0.sigma',5)
    p2 = Parameter('other',expression='2*pi*sin(G0.sigma/.1875) + M1.G1')
    p3 = Parameter('M1.G1',6)
    p4 = Parameter('constant',expression='2*pi*35')
    # Simple chain
    assert set(_find_dependencies(*world(p1,p2,p3))) == set([(p2.path,p1),(p2.path,p3)])
    # Constant expression
    assert set(_find_dependencies(*world(p1,p4))) == set([(p4.path,None)])
    # No dependencies
    assert set(_find_dependencies(*world(p1,p3))) == set([])

    # Check function builder
    fn = compile_constraints(*world(p1,p2,p3))

    # Inspect the resulting function; flip the flag when debugging
    show_generated = False
    if show_generated:
        print(inspect.getdoc(fn))
        dis.dis(fn)

    # Evaluate the function and see if it updates the
    # target value as expected
    fn()
    expected = 2*math.pi*math.sin(5/.1875) + 6
    assert p2.value == expected,"Value was %s, not %s"%(p2.value,expected)

    # Check empty dependency set doesn't crash
    fn = compile_constraints(*world(p1,p3))
    fn()

    # Check that constants are evaluated properly
    fn = compile_constraints(*world(p4))
    fn()
    assert p4.value == 2*math.pi*35

    # Check additional context example; this also tests multiple
    # expressions
    class Table:
        Si = 2.09
        values = {'Si': 2.07}
    tbl = Table()
    p5 = Parameter('lookup',expression="tbl.Si")
    fn = compile_constraints(*world(p1,p2,p3,p5),context=dict(tbl=tbl))
    fn()
    assert p5.value == 2.09,"Value for %s was %s"%(p5.expression,p5.value)
    p5.expression = "tbl.values['Si']"
    fn = compile_constraints(*world(p1,p2,p3,p5),context=dict(tbl=tbl))
    fn()
    assert p5.value == 2.07,"Value for %s was %s"%(p5.expression,p5.value)

    # Verify that we capture invalid expressions
    for expr in ['G4.cage', 'M0.cage', 'M1.G1 + *2',
                 'piddle',
                 '5; import sys; print "p0wned"',
                 '__import__("sys").argv']:
        try:
            p6 = Parameter('broken',expression=expr)
            fn = compile_constraints(*world(p6))
            fn()
        except Exception:
            # Any failure is the expected outcome for a bad expression
            pass
        else:
            raise Exception("Failed to raise error for %s"%expr)
399
# Run the self-tests when the module is executed as a script
if __name__ == "__main__":
    test_expr()
    test_deps()
Note: See TracBrowser for help on using the repository browser.