source: sasview/src/sas/sascalc/fit/expression.py @ cb64d86

magnetic_scattrelease-4.2.2ticket-1009ticket-1249
Last change on this file since cb64d86 was 574adc7, checked in by Paul Kienzle <pkienzle@…>, 2 years ago

convert sascalc to python 2/3 syntax

  • Property mode set to 100644
File size: 13.7 KB
Line 
1# This program is public domain
2"""
3Parameter expression evaluator.
4
5For systems in which constraints are expressed as string expressions rather
6than python code, :func:`compile_constraints` can construct an expression
7evaluator that substitutes the computed values of the expressions into the
8parameters.
9
10The compiler requires a symbol table, an expression set and a context.
11The symbol table maps strings containing fully qualified names such as
12'M1.c[3].full_width' to parameter objects with a 'value' property that
13can be queried and set.  The expression set maps symbol names from the
14symbol table to string expressions.  The context provides additional symbols
15for the expressions in addition to the usual mathematical functions and
16constants.
17
18The expressions are compiled and interpreted by python, with only minimal
19effort to make sure that they don't contain bad code.  The resulting
20constraints function returns 0 so it can be used directly in a fit problem
21definition.
22
23Extracting the symbol table from the model depends on the structure of the
24model.  If fitness.parameters() is set correctly, then this should simply
25be a matter of walking the parameter data, remembering the path to each
26parameter in the symbol table.  For compactness, dictionary elements should
27be referenced by .name rather than ["name"].  Model name can be used as the
28top level.
29
30Getting the parameter expressions applied correctly is challenging.
31The following monkey patch works by overriding model_update in FitProblem
32so that, after setp(p) is called, the constraints expression can be
33applied before telling the underlying fitness function that the model
34is out of date::
35
36        # Override model update so that parameter constraints are applied
37        problem._model_update = problem.model_update
38        def model_update():
39            constraints()
40            problem._model_update()
41        problem.model_update = model_update
42
43Ideally, this interface will change
44"""
45from __future__ import print_function
46
47import math
48import re
49
50# simple pattern which matches symbols.  Note that it will also match
51# invalid substrings such as a3...9, but given syntactically correct
52# input it will only match symbols.
53_symbol_pattern = re.compile('([a-zA-Z_][a-zA-Z_0-9.]*)')
54
def _symbols(expr, symtab):
    """
    Collect the symbol-table entries referenced by an expression.

    Every identifier-like token of *expr* that is a key of *symtab* is
    looked up, and the corresponding *values* of *symtab* are gathered into
    a set, so an entry appears once however often its name occurs.  Tokens
    that are not keys of *symtab* (function names, constants, ...) are
    ignored.  The set has no particular order.

    This is the first step in computing a dependency graph.
    """
    found = set()
    for match in _symbol_pattern.finditer(expr):
        token = match.group(0)
        if token in symtab:
            found.add(symtab[token])
    return found
66
def _substitute(expr, mapping):
    """
    Return *expr* with every symbol that is a key of *mapping* replaced by
    the corresponding mapped text.  Symbols that are not in *mapping*, and
    all text between symbols, are copied through unchanged.
    """
    def _swap(match):
        # Only whole symbols are considered, so a key such as 'a.b' does
        # not touch the longer symbol 'a.b.x'.
        symbol = match.group(1)
        if symbol in mapping:
            return mapping[symbol]
        return match.group(0)

    return _symbol_pattern.sub(_swap, expr)
86
def _find_dependencies(symtab, exprs):
    """
    Build the list of pair-wise dependencies implied by the expressions.

    *exprs* maps a target name to its expression string.  For each target
    one ``(target, source)`` pair is produced per symbol-table entry the
    expression refers to; e.g. for ``p3 = p1+p2`` the result contains
    ``(p3, p1)`` and ``(p3, p2)``.  An expression that refers to nothing in
    *symtab*, such as ``p4 = 2*pi``, contributes the single pair
    ``(p4, None)``.
    """
    deps = []
    for target, expr in exprs.items():
        for source in _symbols_or_none(expr, symtab):
            deps.append((target, source))
    return deps
99
# Workaround for expressions that depend on nothing: report a single fake
# dependency of None so the target still shows up in the dependency pairs.
# A cleaner design would have order_dependencies accept a dictionary of
# {symbol: dependency_list}, where "no dependencies" is simply []; the same
# change would then be needed in _parameter_mapping.
def _symbols_or_none(expr, symtab):
    """
    Like :func:`_symbols`, but return ``[None]`` instead of an empty set.
    """
    found = _symbols(expr, symtab)
    if found:
        return found
    return [None]
108
109def _parameter_mapping(pairs):
110    """
111    Find the parameter substitution we need so that expressions can
112    be evaluated without having to traverse a chain of
113    model.layer.parameter.value
114    """
115    left,right = zip(*pairs)
116    pars = list(sorted(p for p in set(left+right) if p is not None))
117    definition = dict( ('P%d'%i,p)  for i,p in enumerate(pars) )
118    # p is None when there is an expression with no dependencies
119    substitution = dict( (p,'P%d.value'%i)
120                    for i,p in enumerate(sorted(pars))
121                    if p is not None)
122    return definition, substitution
123
def no_constraints():
    """
    Constraint updater for a parameter set with no constraints between
    the parameters.

    Does nothing and returns 0, the same value the generated
    ``eval_expressions`` function returns, so the result can be used
    directly in a fit problem definition.
    """
    return 0
129
def compile_constraints(symtab, exprs, context=None):
    """
    Build and return a function to evaluate all parameter expressions.

    Input:

        *symtab* is the symbol table for the model: { 'name': parameter }

        *exprs* is the set of computed symbols: { 'name': 'expression' }

        *context* is an optional dictionary of additional names made
        available to the expressions

    Return:

        updater function which sets parameter.value for each expression
        and returns 0.  If *exprs* yields no dependencies at all,
        :func:`no_constraints` is returned instead.

    Raises:

       SyntaxError - improper expression syntax (raised while compiling)

       ValueError - raised by :func:`order_dependencies` when it detects
       circular dependencies

    A name used in an expression that is neither in *symtab*, *context*
    nor the math module is not detected here; it raises NameError when
    the returned function is called.  Try running the function once to
    identify such errors before using it in a fit.

    The expressions are evaluated with the contents of the math module
    (plus the aliases arcsin, arccos, arctan and arctan2), then *context*,
    then the parameters, and finally ``id``.  Each symbol-table name is
    replaced by ``Pn.value`` in the generated code; only names matched by
    ``_symbol_pattern`` (letters, digits, underscore and dot) are
    substituted.

    This code has not been fully audited for security.  The builtins are
    withheld from the expressions, so ``__import__`` and friends are not
    available, but there may be other vectors for users to perform more
    than simple function evaluations.  Unauthenticated users should not
    be running this code.

    Use help(fn) to see the assignments behind the returned function fn.
    dis.dis(fn) will show the corresponding python vm instructions.

    NOTE(review): _find_dependencies pairs each target *name* with source
    entries taken from the *values* of symtab.  Unless those values compare
    equal to the names, order_dependencies cannot relate a computed
    parameter to the expressions that use it; verify that chained
    expressions are emitted in a usable order.
    """

    # Sort the parameters in the order they need to be evaluated
    deps = _find_dependencies(symtab, exprs)
    if deps == []:
        return no_constraints
    order = order_dependencies(deps)

    # Rather than using the full path to the parameters in the parameter
    # expressions, instead use Pn, and substitute Pn.value for each occurrence
    # of the parameter in the expression.
    names = sorted(symtab.keys())
    parameters = dict(('P%d'%i, symtab[k]) for i,k in enumerate(names))
    mapping = dict((k, 'P%d.value'%i) for i,k in enumerate(names))

    # Initialize dictionary with available functions.  Later updates win,
    # so context overrides math and parameters override context.
    namespace = {}
    namespace.update(math.__dict__)
    namespace.update(dict(arcsin=math.asin,arccos=math.acos,
                          arctan=math.atan,arctan2=math.atan2))
    if context is not None:
        namespace.update(context)
    namespace.update(parameters)
    namespace['id'] = id
    # Withhold the builtins *before* compiling.  Without this entry exec
    # inserts the real builtins module, and on Python 3.10+ the function
    # keeps the builtins it was created with even if the key is removed
    # from its globals afterwards.
    namespace['__builtins__'] = {}
    scope = {}

    # Define the constraints function
    assignments = ["=".join((p,exprs[p])) for p in order]
    code = [_substitute(s, mapping) for s in assignments]
    functiondef = """
def eval_expressions():
    '''
    %s
    '''
    %s
    return 0
"""%("\n    ".join(assignments),"\n    ".join(code))

    # SECURITY: this executes user-supplied expression text.  The function
    # form of exec is valid on both Python 2.7 and Python 3.
    exec(functiondef, namespace, scope)
    retfn = scope['eval_expressions']

    # Remove module bookkeeping entries that came along with math.__dict__
    # and the exec call; the expressions have no use for them.
    namespace.pop('__doc__', None)
    namespace.pop('__name__', None)
    namespace.pop('__file__', None)
    namespace.pop('__builtins__', None)

    return retfn
223
def order_dependencies(pairs):
    """
    Order the left-hand items of *pairs* by dependency.

    *pairs* is any iterable of 2-item entries ``(a, b)`` (a list, a tuple,
    an iterator or an N x 2 numpy array).  The result is a list holding
    each distinct left-hand item exactly once; items that only ever occur
    on the right are not included.

    Items are resolved in rounds: an item is resolved once every ``b`` it
    is paired with occurs only on the right of the remaining pairs.  The
    accumulated list is reversed before it is returned, so for a pair
    ``(a, b)`` with both items in the result, ``a`` is placed ahead of
    ``b``.  The relative order of items resolved in the same round is not
    defined.

    NOTE(review): the previous docstring stated the opposite ordering
    ("b comes before a"), while the code and the module's own ``_check``
    helper agree on "a before b".  Confirm which one callers need.

    Raises ValueError if the pairs contain a cycle.
    """
    # Normalise to a list of tuples.  Comparing the raw input against []
    # is not a reliable emptiness test: it is True for an empty tuple,
    # meaningless for an iterator and ambiguous or an error for an array.
    pairs = [(a, b) for a, b in pairs]
    order = []
    if not pairs:
        return order

    # Break pairs into left set and right set
    left, right = [set(s) for s in zip(*pairs)]
    while pairs:
        # Find which items only occur on the right
        independent = right - left
        if not independent:
            cycleset = ", ".join(str(s) for s in left)
            raise ValueError("Cyclic dependencies amongst %s"%cycleset)

        # The possibly resolvable items are those that depend on the
        # independents
        dependent = set(a for a, b in pairs if b in independent)
        pairs = [(a, b) for a, b in pairs if b not in independent]
        if pairs:
            left, right = [set(s) for s in zip(*pairs)]
            # Anything still on the left has unresolved partners
            resolved = dependent - left
        else:
            resolved = dependent
        order += resolved
    order.reverse()
    return order
256
# ========= Test code ========
def _check(msg, pairs):
    """
    Verify the result of ``order_dependencies(pairs)``.

    The result must contain exactly the distinct left-hand items of
    *pairs*, each once, and for every pair ``(lo, hi)`` whose items are
    both in the result, ``lo`` must come before ``hi``.  *msg* labels the
    ValueError raised when either condition fails.  *pairs* may be a list
    of 2-tuples or an N x 2 numpy array.
    """
    # Iterate instead of comparing against [], which is not a usable
    # emptiness test for numpy arrays.
    items = set(lo for lo, _ in pairs)
    n = order_dependencies(pairs)
    if set(n) != items or len(n) != len(items):
        n.sort()
        items = sorted(items)
        raise ValueError("%s expect %s to contain %s for %s"%(msg,n,items,pairs))
    # n holds distinct items at this point, so a position lookup table is
    # equivalent to n.index() without the repeated linear scans.
    position = dict((item, index) for index, item in enumerate(n))
    for lo,hi in pairs:
        if lo in position and hi in position and position[lo] >= position[hi]:
            raise ValueError("%s expect %s before %s in %s for %s"%(msg,lo,hi,n,pairs))
273
def test_deps():
    """
    Exercise order_dependencies through _check on empty, small, cyclic,
    large random and deeply chained inputs, as lists and as numpy arrays.
    """
    import numpy as np

    # Null case
    _check("test empty", [])

    # Some dependencies
    _check("test1", [(2,7),(1,5),(1,4),(2,1),(3,1),(5,6)])
    _check("test1 renumbered", [(6,1),(7,3),(7,4),(6,7),(5,7),(3,2)])
    _check("test1 numpy", np.array([(2,7),(1,5),(1,4),(2,1),(3,1),(5,6)]))

    # No dependencies
    _check("test2", [(4,1),(3,2),(8,4)])

    # Cycle test
    pairs = [(1,4),(4,3),(4,5),(5,1)]
    try:
        order_dependencies(pairs)
    except ValueError:
        pass
    else:
        raise ValueError("test3 expect ValueError exception for %s"%(pairs,))

    # large test for gross speed check
    big = np.random.randint(4000, size=(1000, 2))
    big[:, 1] += 4000  # Avoid cycles
    _check("test-large", big)

    # depth tests: one long chain, in each direction
    k = 200
    rising = np.array([range(0, k), range(1, k+1)]).T
    _check("depth-1", rising)

    falling = np.array([range(1, k+1), range(0, k)]).T
    _check("depth-2", falling)
309
def test_expr():
    """
    Exercise symbol lookup, symbol substitution, dependency discovery and
    the compiled constraint function, including expressions that must be
    rejected.
    """
    import inspect, dis
    import math

    symtab = {'a.b.x':1, 'a.c':2, 'a.b':3, 'b.x':4}
    expr = 'a.b.x + sin(4*pi*a.c) + a.b.x/a.b'

    # Check symbol lookup
    assert _symbols(expr, symtab) == set([1,2,3])

    # Check symbol rename
    assert _substitute(expr,{'a.b.x':'Q'}) == 'Q + sin(4*pi*a.c) + Q/a.b'
    assert _substitute(expr,{'a.b':'Q'}) == 'a.b.x + sin(4*pi*a.c) + a.b.x/Q'


    # Check dependency builder
    # Fake parameter class
    class Parameter:
        def __init__(self, name, value=0, expression=''):
            self.path = name
            self.value = value
            self.expression = expression
        def iscomputed(self): return (self.expression != '')
        def __repr__(self): return self.path
    def world(*pars):
        # Build the (symtab, exprs) argument pair from a set of parameters
        symtab = dict((p.path,p) for p in pars)
        exprs = dict((p.path,p.expression) for p in pars if p.iscomputed())
        return symtab, exprs
    p1 = Parameter('G0.sigma',5)
    p2 = Parameter('other',expression='2*pi*sin(G0.sigma/.1875) + M1.G1')
    p3 = Parameter('M1.G1',6)
    p4 = Parameter('constant',expression='2*pi*35')
    # Simple chain
    assert set(_find_dependencies(*world(p1,p2,p3))) == set([(p2.path,p1),(p2.path,p3)])
    # Constant expression
    assert set(_find_dependencies(*world(p1,p4))) == set([(p4.path,None)])
    # No dependencies
    assert set(_find_dependencies(*world(p1,p3))) == set([])

    # Check function builder
    fn = compile_constraints(*world(p1,p2,p3))

    # Inspect the resulting function (debugging aid, disabled by default)
    if 0:
        print(inspect.getdoc(fn))
        print(dis.dis(fn))

    # Evaluate the function and see if it updates the
    # target value as expected
    fn()
    expected = 2*math.pi*math.sin(5/.1875) + 6
    assert p2.value == expected,"Value was %s, not %s"%(p2.value,expected)

    # Check empty dependency set doesn't crash
    fn = compile_constraints(*world(p1,p3))
    fn()

    # Check that constants are evaluated properly
    fn = compile_constraints(*world(p4))
    fn()
    assert p4.value == 2*math.pi*35

    # Check additional context example; this also tests multiple
    # expressions
    class Table:
        Si = 2.09
        values = {'Si': 2.07}
    tbl = Table()
    p5 = Parameter('lookup',expression="tbl.Si")
    fn = compile_constraints(*world(p1,p2,p3,p5),context=dict(tbl=tbl))
    fn()
    assert p5.value == 2.09,"Value for %s was %s"%(p5.expression,p5.value)
    p5.expression = "tbl.values['Si']"
    fn = compile_constraints(*world(p1,p2,p3,p5),context=dict(tbl=tbl))
    fn()
    assert p5.value == 2.07,"Value for %s was %s"%(p5.expression,p5.value)


    # Verify that we capture invalid expressions: unknown names, bad
    # syntax, and attempts to reach the import machinery.
    for expr in ['G4.cage', 'M0.cage', 'M1.G1 + *2',
                 'piddle',
                 '5; import sys; print "p0wned"',
                 '__import__("sys").argv']:
        try:
            p6 = Parameter('broken',expression=expr)
            fn = compile_constraints(*world(p6))
            fn()
        except Exception:
            # Any failure, at compile time or at call time, is the
            # expected outcome for these expressions.
            pass
        else:
            # Raising a bare string is itself a TypeError, so report the
            # failure with a real exception carrying the message.
            raise AssertionError("Failed to raise error for %s"%expr)
402
if __name__ == "__main__":
    # Run the module self-tests when executed as a script.
    for _selftest in (test_expr, test_deps):
        _selftest()
Note: See TracBrowser for help on using the repository browser.