source: sasview/src/sas/sascalc/fit/expression.py @ 99321b2

Last change on this file since 99321b2 was 3c680c1, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

maybe fix doc build error on jenkins mac

  • Property mode set to 100644
File size: 14.2 KB
Line 
1# This program is public domain
2"""
3Parameter expression evaluator.
4
5For systems in which constraints are expressed as string expressions rather
6than python code, :func:`compile_constraints` can construct an expression
7evaluator that substitutes the computed values of the expressions into the
8parameters.
9
10The compiler requires a symbol table, an expression set and a context.
11The symbol table maps strings containing fully qualified names such as
12'M1.c[3].full_width' to parameter objects with a 'value' property that
13can be queried and set.  The expression set maps symbol names from the
14symbol table to string expressions.  The context provides additional symbols
15for the expressions in addition to the usual mathematical functions and
16constants.
17
18The expressions are compiled and interpreted by python, with only minimal
19effort to make sure that they don't contain bad code.  The resulting
20constraints function returns 0 so it can be used directly in a fit problem
21definition.
22
23Extracting the symbol table from the model depends on the structure of the
24model.  If fitness.parameters() is set correctly, then this should simply
25be a matter of walking the parameter data, remembering the path to each
26parameter in the symbol table.  For compactness, dictionary elements should
27be referenced by .name rather than ["name"].  Model name can be used as the
28top level.
29
30Getting the parameter expressions applied correctly is challenging.
31The following monkey patch works by overriding model_update in FitProblem
32so that, after setp(p) is called, the constraints expression can be
33applied before telling the underlying fitness function that the model
34is out of date::
35
36        # Override model update so that parameter constraints are applied
37        problem._model_update = problem.model_update
38        def model_update():
39            constraints()
40            problem._model_update()
41        problem.model_update = model_update
42
43Ideally, this interface will change
44"""
45from __future__ import print_function
46
47import math
48import re
49
# Loose matcher for symbol names: a letter or underscore followed by any
# run of letters, digits, underscores and dots.  It would also accept
# malformed fragments such as a3...9, but when the input expression is
# syntactically correct every match is a symbol.
_symbol_pattern = re.compile(r'([a-zA-Z_][a-zA-Z_0-9.]*)')
54
def _symbols(expr,symtab):
    """
    Collect the symbol table entries referenced by an expression.

    Each name matched in *expr* by the module-level symbol pattern is looked
    up in *symtab*; names that are not keys of *symtab* (function names,
    constants, ...) are ignored.  The result is the set of the corresponding
    *symtab* values, so a symbol used several times is reported once and the
    elements have no particular order.

    This is the first step in computing a dependency graph.
    """
    found = set()
    for match in _symbol_pattern.finditer(expr):
        name = match.group(0)
        if name in symtab:
            found.add(symtab[name])
    return found
66
def _substitute(expr,mapping):
    """
    Return *expr* with every symbol s that is a key of *mapping* replaced
    by mapping[s].

    Only whole symbols (as matched by the module-level symbol pattern) are
    replaced; text between the symbols is copied through unchanged.
    """
    pieces = []
    cursor = 0  # end of the last replaced symbol in expr
    for match in _symbol_pattern.finditer(expr):
        name = match.group(1)
        if name not in mapping:
            continue
        # Copy the untouched text before the symbol, then its replacement
        pieces.append(expr[cursor:match.start()])
        pieces.append(mapping[name])
        cursor = match.end()

    # Whatever follows the last replaced symbol
    pieces.append(expr[cursor:])
    return "".join(pieces)
86
def _find_dependencies(symtab, exprs):
    """
    Return the list of pair-wise dependencies for the parameter expressions.

    *symtab* maps symbol names to parameters and *exprs* maps the names of
    computed symbols to their expression strings.

    Each entry of the result is a pair (target, source) where target is a
    key of *exprs* and source is one of the *symtab* values referenced by
    the corresponding expression.  For example, if p3 = p1+p2 then the
    result holds one pair for p3 with p1 and one for p3 with p2.  An
    expression that references no symbol from *symtab*, such as
    p4 = 2*pi, contributes the single pair (p4, None).
    """
    deps = []
    for target, expr in exprs.items():
        for source in _symbols_or_none(expr, symtab):
            deps.append((target, source))
    return deps
99
# Workaround for expressions that have no dependencies: report a single
# fake dependency of None instead of an empty collection.
# A cleaner design would have order_dependencies accept a dictionary of
# {symbol: dependency_list}, where "no dependencies" is simply []; the same
# change would then be needed in _parameter_mapping.
def _symbols_or_none(expr,symtab):
    """
    Return the symbols used by *expr* (see _symbols), or [None] when the
    expression uses no symbol from *symtab*.
    """
    syms = _symbols(expr,symtab)
    if not syms:
        return [None]
    return syms
108
def _parameter_mapping(pairs):
    """
    Build the parameter substitution tables for a list of dependency pairs.

    Every distinct non-None element appearing on either side of *pairs* is
    given a short alias Pn, numbered in sorted order, so that expressions
    can be evaluated without traversing a chain such as
    model.layer.parameter.value.

    Returns (definition, substitution) where definition maps 'Pn' to the
    element and substitution maps the element to the string 'Pn.value'.
    """
    left, right = zip(*pairs)
    # None marks an expression with no dependencies; it is not a parameter
    pars = sorted(p for p in set(left + right) if p is not None)

    definition = {}
    substitution = {}
    for index, par in enumerate(pars):
        alias = 'P%d' % index
        definition[alias] = par
        substitution[par] = alias + '.value'
    return definition, substitution
123
def no_constraints():
    """
    Do-nothing updater, used when the parameter set has no constraints
    between the parameters.
    """
    return None
129
def compile_constraints(symtab, exprs, context={}):
    """
    Build and return a function to evaluate all parameter expressions in
    the proper order.

    Input:

        *symtab* is the symbol table for the model: { 'name': parameter }

        *exprs* is the set of computed symbols: { 'name': 'expression' }

        *context* is any additional context needed to evaluate the
        expression.  It is only read, never modified.

    Return:

        updater function which sets parameter.value for each expression and
        returns 0, or :func:`no_constraints` when *exprs* yields no
        dependencies.

    Raises:

       SyntaxError - improper expression syntax

       ValueError - expressions have circular dependencies

    NOTE(review): earlier documentation also listed AssertionError for a
    missing model, parameter or function, but nothing in this function
    raises it; an unknown name in an expression surfaces as a NameError
    when the returned function is called.

    NOTE(review): _find_dependencies pairs target *names* with the symtab
    *values* they reference, so a chain of computed parameters may not be
    ordered relative to each other by order_dependencies - verify before
    relying on chained expressions.

    This function is not terribly sophisticated, and it would be easy to
    trick.  However it handles the common cases cleanly and generates
    reasonable messages for the common errors.

    This code has not been fully audited for security.  The generated
    function is created with an empty set of builtins, which also removes
    the ability to import modules, but there may be other vectors for users
    to perform more than simple function evaluations.  Unauthenticated
    users should not be running this code.

    Parameter names are assumed to contain only _.a-zA-Z0-9#[]

    Both names are provided for inverse functions, e.g., acos and arccos.

    Should try running the function to identify syntax errors before
    running it in a fit.

    Use help(fn) to see the code generated for the returned function fn.
    dis.dis(fn) will show the corresponding python vm instructions.
    """

    # Sort the parameters in the order they need to be evaluated
    deps = _find_dependencies(symtab, exprs)
    if deps == []:
        return no_constraints
    order = order_dependencies(deps)

    # Rather than using the full path to the parameters in the parameter
    # expressions, instead use Pn, and substitute Pn.value for each occurrence
    # of the parameter in the expression.
    names = list(sorted(symtab.keys()))
    parameters = dict(('P%d'%i, symtab[k]) for i,k in enumerate(names))
    mapping = dict((k, 'P%d.value'%i) for i,k in enumerate(names))

    # Initialize dictionary with available functions
    global_context = {}
    global_context.update(math.__dict__)
    global_context.update(dict(arcsin=math.asin,arccos=math.acos,
                               arctan=math.atan,arctan2=math.atan2))
    global_context.update(context)
    global_context.update(parameters)
    global_context['id'] = id
    # Block the builtins *before* the function is created.  Since python 3.10
    # a function captures its builtins from globals['__builtins__'] at
    # creation time, so removing the entry afterwards (as was done
    # previously) leaves __import__, open, etc. reachable from expressions.
    # This is set after the caller's context is merged so that the context
    # cannot reintroduce the builtins.
    global_context['__builtins__'] = {}
    local_context = {}

    # Define the constraints function
    assignments = ["=".join((p,exprs[p])) for p in order]
    code = [_substitute(s, mapping) for s in assignments]
    functiondef = """
def eval_expressions():
    '''
    %s
    '''
    %s
    return 0
"""%("\n    ".join(assignments),"\n    ".join(code))

    # CRUFT: python < 3.0;  doc builder isn't allowing a plain exec here, so
    # compile the source and run it through eval instead.
    # https://stackoverflow.com/questions/4484872/why-doesnt-exec-work-in-a-function-with-a-subfunction/41368813#comment73790496_41368813
    # SECURITY: the expressions are user supplied text; see the docstring.
    eval(compile(functiondef, '<string>', 'exec'), global_context, local_context)
    retfn = local_context['eval_expressions']

    # Remove module attributes copied in from math and the builtins entry so
    # that they are not visible as names inside the expressions.
    global_context.pop('__doc__', None)
    global_context.pop('__name__', None)
    global_context.pop('__file__', None)
    global_context.pop('__builtins__', None)

    return retfn
226
def order_dependencies(pairs):
    """
    Order the left-hand elements of *pairs* consistently with the pairs.

    *pairs* is a list (or array) of (a,b) pairs.  The returned list contains
    each element that occurs on the left of some pair exactly once, arranged
    so that a comes before b whenever both a and b are in the list.

    NOTE(review): earlier documentation stated that b comes before a; the
    implementation (and the ordering checked by _check in this module)
    places a before b - confirm which order callers expect.

    Raises ValueError if the pairs contain a cycle.
    """
    order = []

    # Note: pairs may be an array, so test emptiness with len() rather than
    # truthiness or comparison against [].
    if len(pairs) > 0:
        left, right = [set(s) for s in zip(*pairs)]

    while len(pairs) > 0:
        # Items which only occur on the right depend on nothing else
        independent = right - left
        if not independent:
            cycleset = ", ".join(str(s) for s in left)
            raise ValueError("Cyclic dependencies amongst %s"%cycleset)

        # Split the pairs: those hanging off an independent item are dropped
        # and their left side becomes a candidate for resolution.
        dependent = set()
        remaining = []
        for a, b in pairs:
            if b in independent:
                dependent.add(a)
            else:
                remaining.append((a, b))
        pairs = remaining

        if pairs:
            # A candidate that still appears on the left of a remaining pair
            # has unresolved dependencies and must wait for a later pass.
            left, right = [set(s) for s in zip(*pairs)]
            resolved = dependent - left
        else:
            resolved = dependent
        order.extend(resolved)

    order.reverse()
    return order
260
261# ========= Test code ========
def _check(msg,pairs):
    """
    Run order_dependencies on *pairs* and verify the result.

    The returned list must contain exactly the left-hand elements of
    *pairs*, each once, and for every pair (lo,hi) whose elements are both
    in the list, lo must come before hi.  Raises ValueError, tagged with
    *msg*, when either condition fails.
    """
    # Note: pairs is array or list, so use "len(pairs) > 0" to check for empty.
    if len(pairs) > 0:
        left, right = zip(*pairs)
    else:
        left, right = [], []
    items = set(left)
    n = order_dependencies(pairs)

    if set(n) != items or len(n) != len(items):
        n.sort()
        items = sorted(items)
        raise ValueError("%s expect %s to contain %s for %s"%(msg,n,items,pairs))

    # n holds each item once at this point, so a position table is
    # equivalent to repeated n.index() lookups.
    position = dict((item, index) for index, item in enumerate(n))
    for lo,hi in pairs:
        if lo in position and hi in position and position[lo] >= position[hi]:
            raise ValueError("%s expect %s before %s in %s for %s"%(msg,lo,hi,n,pairs))
278
def test_deps():
    """
    Exercise order_dependencies on empty, small, cyclic, random and deeply
    chained inputs, given both as lists and as numpy arrays.
    """
    import numpy as np

    # Null case
    _check("test empty",[])

    # Some dependencies, as a list, renumbered, and as an array
    sample = [(2,7),(1,5),(1,4),(2,1),(3,1),(5,6)]
    _check("test1",sample)
    _check("test1 renumbered",[(6,1),(7,3),(7,4),(6,7),(5,7),(3,2)])
    _check("test1 numpy",np.array(sample))

    # A short chain plus an unrelated pair
    _check("test2",[(4,1),(3,2),(8,4)])

    # Cycle test: the call must fail with ValueError
    pairs = [(1,4),(4,3),(4,5),(5,1)]
    cycle_detected = False
    try:
        order_dependencies(pairs)
    except ValueError:
        cycle_detected = True
    if not cycle_detected:
        raise ValueError("test3 expect ValueError exception for %s"%(pairs,))

    # large test for gross speed check
    A = np.random.randint(4000,size=(1000,2))
    A[:,1] += 4000  # Avoid cycles
    _check("test-large",A)

    # depth tests: one long chain, in each direction
    k = 200
    ascending = np.array([range(0,k),range(1,k+1)]).T
    _check("depth-1",ascending)

    descending = np.array([range(1,k+1),range(0,k)]).T
    _check("depth-2",descending)
314
def test_expr():
    """
    Exercise symbol lookup, symbol substitution, dependency discovery and
    compile_constraints, including expressions that must be rejected.
    """
    import inspect
    import dis
    import math

    symtab = {'a.b.x':1, 'a.c':2, 'a.b':3, 'b.x':4}
    expr = 'a.b.x + sin(4*pi*a.c) + a.b.x/a.b'

    # Check symbol lookup
    assert _symbols(expr, symtab) == set([1,2,3])

    # Check symbol rename
    assert _substitute(expr,{'a.b.x':'Q'}) == 'Q + sin(4*pi*a.c) + Q/a.b'
    assert _substitute(expr,{'a.b':'Q'}) == 'a.b.x + sin(4*pi*a.c) + a.b.x/Q'


    # Check dependency builder
    # Fake parameter class
    class Parameter:
        def __init__(self, name, value=0, expression=''):
            self.path = name
            self.value = value
            self.expression = expression
        def iscomputed(self): return (self.expression != '')
        def __repr__(self): return self.path
    def world(*pars):
        # Build the (symtab, exprs) pair expected by the functions under test
        symtab = dict((p.path,p) for p in pars)
        exprs = dict((p.path,p.expression) for p in pars if p.iscomputed())
        return symtab, exprs
    p1 = Parameter('G0.sigma',5)
    p2 = Parameter('other',expression='2*pi*sin(G0.sigma/.1875) + M1.G1')
    p3 = Parameter('M1.G1',6)
    p4 = Parameter('constant',expression='2*pi*35')
    # Simple chain
    assert set(_find_dependencies(*world(p1,p2,p3))) == set([(p2.path,p1),(p2.path,p3)])
    # Constant expression
    assert set(_find_dependencies(*world(p1,p4))) == set([(p4.path,None)])
    # No dependencies
    assert set(_find_dependencies(*world(p1,p3))) == set([])

    # Check function builder
    fn = compile_constraints(*world(p1,p2,p3))

    # Inspect the resulting function
    if 0:
        print(inspect.getdoc(fn))
        print(dis.dis(fn))

    # Evaluate the function and see if it updates the
    # target value as expected
    fn()
    expected = 2*math.pi*math.sin(5/.1875) + 6
    assert p2.value == expected,"Value was %s, not %s"%(p2.value,expected)

    # Check empty dependency set doesn't crash
    fn = compile_constraints(*world(p1,p3))
    fn()

    # Check that constants are evaluated properly
    fn = compile_constraints(*world(p4))
    fn()
    assert p4.value == 2*math.pi*35

    # Check additional context example; this also tests multiple
    # expressions
    class Table:
        Si = 2.09
        values = {'Si': 2.07}
    tbl = Table()
    p5 = Parameter('lookup',expression="tbl.Si")
    fn = compile_constraints(*world(p1,p2,p3,p5),context=dict(tbl=tbl))
    fn()
    assert p5.value == 2.09,"Value for %s was %s"%(p5.expression,p5.value)
    p5.expression = "tbl.values['Si']"
    fn = compile_constraints(*world(p1,p2,p3,p5),context=dict(tbl=tbl))
    fn()
    assert p5.value == 2.07,"Value for %s was %s"%(p5.expression,p5.value)


    # Verify that we capture invalid expressions
    for expr in ['G4.cage', 'M0.cage', 'M1.G1 + *2',
                 'piddle',
                 '5; import sys; print "p0wned"',
                 '__import__("sys").argv']:
        try:
            p6 = Parameter('broken',expression=expr)
            fn = compile_constraints(*world(p6))
            fn()
        except Exception:
            # Any failure, at compile time or at call time, is the expected
            # outcome for these expressions.
            pass
        else:
            # A bare string cannot be raised; use a real exception so the
            # failure message is actually reported.
            raise AssertionError("Failed to raise error for %s"%expr)
407
if __name__ == "__main__":
    # Run the module self-tests when executed as a script
    for _selftest in (test_expr, test_deps):
        _selftest()
Note: See TracBrowser for help on using the repository browser.