source: sasview/src/sas/sascalc/fit/expression.py @ 1309205b

Last change on this file since 1309205b was e4c475b7, checked in by Piotr Rozyczko <rozyczko@…>, 7 years ago

Minor fixes

  • Property mode set to 100644
File size: 13.7 KB
Line 
1# This program is public domain
2"""
3Parameter expression evaluator.
4
5For systems in which constraints are expressed as string expressions rather
6than python code, :func:`compile_constraints` can construct an expression
7evaluator that substitutes the computed values of the expressions into the
8parameters.
9
10The compiler requires a symbol table, an expression set and a context.
11The symbol table maps strings containing fully qualified names such as
12'M1.c[3].full_width' to parameter objects with a 'value' property that
13can be queried and set.  The expression set maps symbol names from the
14symbol table to string expressions.  The context provides additional symbols
15for the expressions in addition to the usual mathematical functions and
16constants.
17
18The expressions are compiled and interpreted by python, with only minimal
19effort to make sure that they don't contain bad code.  The resulting
20constraints function returns 0 so it can be used directly in a fit problem
21definition.
22
23Extracting the symbol table from the model depends on the structure of the
24model.  If fitness.parameters() is set correctly, then this should simply
25be a matter of walking the parameter data, remembering the path to each
26parameter in the symbol table.  For compactness, dictionary elements should
27be referenced by .name rather than ["name"].  Model name can be used as the
28top level.
29
30Getting the parameter expressions applied correctly is challenging.
31The following monkey patch works by overriding model_update in FitProblem
32so that, after setp(p) is called, the constraints expression can be
33applied before telling the underlying fitness function that the model
34is out of date::
35
36        # Override model update so that parameter constraints are applied
37        problem._model_update = problem.model_update
38        def model_update():
39            constraints()
40            problem._model_update()
41        problem.model_update = model_update
42
43Ideally, this interface will change
44"""
45from __future__ import print_function
46
47import math
48import re
49
50# simple pattern which matches symbols.  Note that it will also match
51# invalid substrings such as a3...9, but given syntactically correct
52# input it will only match symbols.
53_symbol_pattern = re.compile('([a-zA-Z_][a-zA-Z_0-9.]*)')
54
def _symbols(expr, symtab):
    """
    Collect the symbol table entries referenced by an expression.

    Every identifier-like token in *expr* that is a key of *symtab* adds
    ``symtab[token]`` to the result.  The result is a set, so repeated
    references count once and the order is unspecified.

    This is the first step in computing a dependency graph.
    """
    found = set()
    for match in _symbol_pattern.finditer(expr):
        token = match.group(0)
        if token in symtab:
            found.add(symtab[token])
    return found
66
def _substitute(expr, mapping):
    """
    Return *expr* with each symbol s found in *mapping* replaced by mapping[s].

    Symbols which are not keys of *mapping* are left as they are.
    """
    def _swap(match):
        # Keep the matched text when there is no replacement for the symbol
        symbol = match.group(1)
        if symbol in mapping:
            return mapping[symbol]
        return match.group(0)

    # A callable replacement is inserted literally (no escape processing)
    return _symbol_pattern.sub(_swap, expr)
86
def _find_dependencies(symtab, exprs):
    """
    Build the list of pair-wise dependencies for the parameter expressions.

    *exprs* maps a target name to its expression string.  Each result pair
    is (target, source), where target is the key from *exprs* and source is
    a *symtab* value whose name occurs in the expression.  For example, with
    exprs = {'p3': 'p1+p2'} the result is [('p3', symtab['p1']),
    ('p3', symtab['p2'])] in some order.  An expression which references
    no symbols, such as p4 = 2*pi, contributes the single pair (p4, None).
    """
    pairs = []
    for target, expr in exprs.items():
        for source in _symbols_or_none(expr, symtab):
            pairs.append((target, source))
    return pairs
99
100# Hack to deal with expressions without dependencies --- return a fake
101# dependency of None.
102# The better solution is fix order_dependencies so that it takes a
103# dictionary of {symbol: dependency_list}, for which no dependencies
104# is simply []; fix in parameter_mapping as well
def _symbols_or_none(expr, symtab):
    """
    Return the symbols used by *expr*, or [None] if it uses none.

    The [None] placeholder gives an expression without dependencies a fake
    dependency, so that it still yields one (target, None) pair.
    """
    referenced = _symbols(expr, symtab)
    if referenced:
        return referenced
    return [None]
108
109def _parameter_mapping(pairs):
110    """
111    Find the parameter substitution we need so that expressions can
112    be evaluated without having to traverse a chain of
113    model.layer.parameter.value
114    """
115    left,right = zip(*pairs)
116    pars = list(sorted(p for p in set(left+right) if p is not None))
117    definition = dict( ('P%d'%i,p)  for i,p in enumerate(pars) )
118    # p is None when there is an expression with no dependencies
119    substitution = dict( (p,'P%d.value'%i)
120                    for i,p in enumerate(sorted(pars))
121                    if p is not None)
122    return definition, substitution
123
def no_constraints():
    """
    This parameter set has no constraints between the parameters.

    Returns 0, like the generated constraints function, so that it can be
    used in its place directly in a fit problem definition.
    """
    return 0
129
def compile_constraints(symtab, exprs, context=None):
    """
    Build and return a function to evaluate all parameter expressions in
    the proper order.

    Input:

        *symtab* is the symbol table for the model: { 'name': parameter }

        *exprs* is the set of computed symbols: { 'name': 'expression' }

        *context* is any additional context needed to evaluate the
        expression: { 'name': object }.  It is not modified.

    Return:

        updater function which sets parameter.value for each expression
        and returns 0; :func:`no_constraints` if *exprs* is empty.

    Raises:

       SyntaxError - improper expression syntax (raised while compiling)

       ValueError - expressions have circular dependencies

    A reference to a missing model, parameter or function is not detected
    here; it raises NameError when the returned function is called.

    This function is not terribly sophisticated, and it would be easy to
    trick.  However it handles the common cases cleanly and generates
    reasonable messages for the common errors.

    This code has not been fully audited for security.  While we have
    removed the builtins and the ability to import modules, there may
    be other vectors for users to perform more than simple function
    evaluations.  Unauthenticated users should not be running this code.

    Parameter names are assumed to contain only _.a-zA-Z0-9#[]

    Both names are provided for inverse functions, e.g., acos and arccos.

    Should try running the function to identify syntax errors before
    running it in a fit.

    Use help(fn) to see the code generated for the returned function fn.
    dis.dis(fn) will show the corresponding python vm instructions.
    """

    # Sort the parameters in the order they need to be evaluated
    # NOTE(review): _find_dependencies pairs each expression *name* with the
    # symtab *values* it references, so order_dependencies can only see a
    # chain between two computed parameters if those compare equal.  Confirm
    # that chained constraints are evaluated in the intended order.
    deps = _find_dependencies(symtab, exprs)
    if deps == []:
        return no_constraints
    order = order_dependencies(deps)

    # Rather than using the full path to the parameters in the parameter
    # expressions, instead use Pn, and substitute Pn.value for each occurrence
    # of the parameter in the expression.
    names = sorted(symtab.keys())
    parameters = dict(('P%d'%i, symtab[k]) for i, k in enumerate(names))
    mapping = dict((k, 'P%d.value'%i) for i, k in enumerate(names))

    # Initialize dictionary with available functions
    namespace = {}
    namespace.update(math.__dict__)
    namespace.update(dict(arcsin=math.asin, arccos=math.acos,
                          arctan=math.atan, arctan2=math.atan2))
    if context:
        namespace.update(context)
    namespace.update(parameters)
    namespace['id'] = id
    # Install an empty builtins table *before* compiling.  A function records
    # its builtins when it is created (Python 3.10+), so removing the entry
    # after exec, as was done previously, leaves __import__ and friends
    # reachable from the expressions.  Set last so context cannot override it.
    namespace['__builtins__'] = {}
    scope = {}

    # Define the constraints function
    assignments = ["=".join((p, exprs[p])) for p in order]
    code = [_substitute(s, mapping) for s in assignments]
    functiondef = """
def eval_expressions():
    '''
    %s
    '''
    %s
    return 0
"""%("\n    ".join(assignments),"\n    ".join(code))

    # SECURITY: this executes user-supplied expression text.  The empty
    # builtins table limits, but does not eliminate, what it can do.
    exec(functiondef, namespace, scope)

    retfn = scope['eval_expressions']

    # Remove module metadata copied in from math.__dict__.  The empty
    # __builtins__ entry is kept so that the restriction also applies on
    # interpreters which look it up in the function globals at call time.
    namespace.pop('__doc__', None)
    namespace.pop('__name__', None)
    namespace.pop('__file__', None)

    return retfn
227
def order_dependencies(pairs):
    """
    Order the left-hand elements of *pairs* according to the pairs.

    *pairs* is any iterable of two-element items (a, b): a list or tuple of
    tuples, a generator, or an n x 2 array.  The elements must be hashable.

    Returns a list containing each distinct left-hand element exactly once,
    arranged so that a comes before b for every pair (a, b) in which b is
    itself a left-hand element.  Elements which only occur on the right are
    not returned.  Elements not related by a pair are in arbitrary order.
    An empty *pairs* gives an empty list.

    Raises ValueError if the pairs contain a cycle.

    NOTE(review): the previous docstring stated the opposite direction
    (b before a), which is what evaluating (target, source) pairs would
    need.  The final reverse() yields a before b, and that is what _check
    verifies, so the behaviour is left unchanged here.  Confirm which
    direction is intended.
    """
    # Materialize the input so that emptiness can be tested by truth value.
    # Comparing against [] is wrong for an empty tuple and is ambiguous (or
    # an error) for numpy arrays.
    pairs = list(pairs)
    if not pairs:
        return []

    order = []

    # Break pairs into left set and right set
    left, right = [set(s) for s in zip(*pairs)]
    while pairs:
        # Find which items only occur on the right
        independent = right - left
        if not independent:
            cycleset = ", ".join(str(s) for s in left)
            raise ValueError("Cyclic dependencies amongst %s"%cycleset)

        # The possibly resolvable items are those that depend on the independents
        dependent = set(a for a, b in pairs if b in independent)
        pairs = [(a, b) for a, b in pairs if b not in independent]
        if not pairs:
            resolved = dependent
        else:
            left, right = [set(s) for s in zip(*pairs)]
            # An item still on the left has unresolved pairs remaining
            resolved = dependent - left
        order += resolved
    order.reverse()
    return order
260
261# ========= Test code ========
def _check(msg, pairs):
    """
    Verify the result of order_dependencies for *pairs*.

    The ordered list must contain each left-hand element of *pairs* exactly
    once and nothing else, and for every pair (lo, hi) with both elements in
    the list, lo must come before hi.  *pairs* may be a list of tuples or an
    n x 2 array.  *msg* labels the failing case.

    Raises ValueError describing the first violation found.
    """
    # Collect the left column by iteration rather than zip(*pairs) guarded by
    # "pairs != []"; that comparison is not a valid emptiness test for arrays.
    items = set(lo for lo, _ in pairs)
    n = order_dependencies(pairs)
    if set(n) != items or len(n) != len(items):
        n.sort()
        items = sorted(items)
        raise ValueError("%s expect %s to contain %s for %s"%(msg,n,items,pairs))
    for lo, hi in pairs:
        if lo in n and hi in n and n.index(lo) >= n.index(hi):
            raise ValueError("%s expect %s before %s in %s for %s"%(msg,lo,hi,n,pairs))
277
def test_deps():
    """
    Exercise order_dependencies on empty, small, cyclic, large and deep inputs.
    """
    import numpy as np

    chain = [(2,7),(1,5),(1,4),(2,1),(3,1),(5,6)]
    cases = (
        # Null case
        ("test empty", []),
        # Some dependencies
        ("test1", chain),
        ("test1 renumbered", [(6,1),(7,3),(7,4),(6,7),(5,7),(3,2)]),
        ("test1 numpy", np.array(chain)),
        # No dependencies
        ("test2", [(4,1),(3,2),(8,4)]),
    )
    for label, case in cases:
        _check(label, case)

    # Cycle test
    pairs = [(1,4),(4,3),(4,5),(5,1)]
    try:
        order_dependencies(pairs)
    except ValueError:
        pass
    else:
        raise ValueError("test3 expect ValueError exception for %s"%(pairs,))

    # large test for gross speed check
    big = np.random.randint(4000,size=(1000,2))
    big[:,1] += 4000  # Avoid cycles
    _check("test-large", big)

    # depth tests: one long chain, in each direction
    k = 200
    rising = np.array([range(0,k),range(1,k+1)]).T
    _check("depth-1", rising)

    falling = np.array([range(1,k+1),range(0,k)]).T
    _check("depth-2", falling)
313
def test_expr():
    """
    Exercise symbol lookup, substitution, dependency building and the
    compiled constraints function, including rejection of bad expressions.
    """
    import inspect, dis
    import math

    symtab = {'a.b.x':1, 'a.c':2, 'a.b':3, 'b.x':4}
    expr = 'a.b.x + sin(4*pi*a.c) + a.b.x/a.b'

    # Check symbol lookup
    assert _symbols(expr, symtab) == set([1,2,3])

    # Check symbol rename
    assert _substitute(expr,{'a.b.x':'Q'}) == 'Q + sin(4*pi*a.c) + Q/a.b'
    assert _substitute(expr,{'a.b':'Q'}) == 'a.b.x + sin(4*pi*a.c) + a.b.x/Q'


    # Check dependency builder
    # Fake parameter class
    class Parameter:
        def __init__(self, name, value=0, expression=''):
            self.path = name
            self.value = value
            self.expression = expression
        def iscomputed(self): return (self.expression != '')
        def __repr__(self): return self.path
    def world(*pars):
        symtab = dict((p.path,p) for p in pars)
        exprs = dict((p.path,p.expression) for p in pars if p.iscomputed())
        return symtab, exprs
    p1 = Parameter('G0.sigma',5)
    p2 = Parameter('other',expression='2*pi*sin(G0.sigma/.1875) + M1.G1')
    p3 = Parameter('M1.G1',6)
    p4 = Parameter('constant',expression='2*pi*35')
    # Simple chain
    assert set(_find_dependencies(*world(p1,p2,p3))) == set([(p2.path,p1),(p2.path,p3)])
    # Constant expression
    assert set(_find_dependencies(*world(p1,p4))) == set([(p4.path,None)])
    # No dependencies
    assert set(_find_dependencies(*world(p1,p3))) == set([])

    # Check function builder
    fn = compile_constraints(*world(p1,p2,p3))

    # Inspect the resulting function
    if 0:
        print(inspect.getdoc(fn))
        print(dis.dis(fn))

    # Evaluate the function and see if it updates the
    # target value as expected
    fn()
    expected = 2*math.pi*math.sin(5/.1875) + 6
    assert p2.value == expected,"Value was %s, not %s"%(p2.value,expected)

    # Check empty dependency set doesn't crash
    fn = compile_constraints(*world(p1,p3))
    fn()

    # Check that constants are evaluated properly
    fn = compile_constraints(*world(p4))
    fn()
    assert p4.value == 2*math.pi*35

    # Check additional context example; this also tests multiple
    # expressions
    class Table:
        Si = 2.09
        values = {'Si': 2.07}
    tbl = Table()
    p5 = Parameter('lookup',expression="tbl.Si")
    fn = compile_constraints(*world(p1,p2,p3,p5),context=dict(tbl=tbl))
    fn()
    assert p5.value == 2.09,"Value for %s was %s"%(p5.expression,p5.value)
    p5.expression = "tbl.values['Si']"
    fn = compile_constraints(*world(p1,p2,p3,p5),context=dict(tbl=tbl))
    fn()
    assert p5.value == 2.07,"Value for %s was %s"%(p5.expression,p5.value)


    # Verify that we capture invalid expressions
    for expr in ['G4.cage', 'M0.cage', 'M1.G1 + *2',
                 'piddle',
                 '5; import sys; print "p0wned"',
                 '__import__("sys").argv']:
        try:
            p6 = Parameter('broken',expression=expr)
            fn = compile_constraints(*world(p6))
            fn()
        except Exception:
            # Any failure, whether at compile time or when evaluated, counts
            # as the bad expression having been rejected
            pass
        else:
            # A bare string cannot be raised; use a real exception so that
            # the message naming the offending expression is reported
            raise AssertionError("Failed to raise error for %s"%expr)
406
if __name__ == "__main__":
    # Run the module self-tests: expressions first, then dependency ordering
    for _selftest in (test_expr, test_deps):
        _selftest()
Note: See TracBrowser for help on using the repository browser.