source: sasview/src/sas/sascalc/pr/fit/expression.py

ESS_GUI
Last change on this file was aed159f, checked in by Piotr Rozyczko <piotr.rozyczko@…>, 5 years ago

Minor corrections to Inversion after PK's CR

  • Property mode set to 100644
File size: 13.6 KB
Line 
1from __future__ import print_function
2
3# This program is public domain
4"""
5Parameter expression evaluator.
6
7For systems in which constraints are expressed as string expressions rather
8than python code, :func:`compile_constraints` can construct an expression
9evaluator that substitutes the computed values of the expressions into the
10parameters.
11
12The compiler requires a symbol table, an expression set and a context.
13The symbol table maps strings containing fully qualified names such as
14'M1.c[3].full_width' to parameter objects with a 'value' property that
15can be queried and set.  The expression set maps symbol names from the
16symbol table to string expressions.  The context provides additional symbols
17for the expressions in addition to the usual mathematical functions and
18constants.
19
20The expressions are compiled and interpreted by python, with only minimal
21effort to make sure that they don't contain bad code.  The resulting
22constraints function returns 0 so it can be used directly in a fit problem
23definition.
24
25Extracting the symbol table from the model depends on the structure of the
26model.  If fitness.parameters() is set correctly, then this should simply
27be a matter of walking the parameter data, remembering the path to each
28parameter in the symbol table.  For compactness, dictionary elements should
29be referenced by .name rather than ["name"].  Model name can be used as the
30top level.
31
32Getting the parameter expressions applied correctly is challenging.
The following monkey patch works by overriding model_update in FitProblem
so that, after setp(p) is called, the constraints expression can be
applied before telling the underlying fitness function that the model
is out of date::
37
38        # Override model update so that parameter constraints are applied
39        problem._model_update = problem.model_update
40        def model_update():
41            constraints()
42            problem._model_update()
43        problem.model_update = model_update
44
45Ideally, this interface will change
46"""
47import math
48import re
49
# Loose matcher for dotted identifiers such as 'M1.c.full_width'.  It is
# deliberately simple: it also accepts malformed runs such as a3...9, but
# given syntactically correct input it will only match symbols.
_symbol_pattern = re.compile(r'([a-zA-Z_][a-zA-Z_0-9.]*)')
54
def _symbols(expr,symtab):
    """
    Collect the symbol table entries referenced by an expression string.

    Each identifier-like token of *expr* that is a key of *symtab* is looked
    up, and the looked-up values (not the names) are gathered into a set.
    A symbol that occurs several times is therefore reported once, and the
    result has no particular order.  Tokens that are not in *symtab* are
    ignored.

    This is the first step in computing a dependency graph.
    """
    found = set()
    for match in _symbol_pattern.finditer(expr):
        name = match.group(0)
        if name in symtab:
            found.add(symtab[name])
    return found
66
def _substitute(expr,mapping):
    """
    Return *expr* with each symbol s that is a key of *mapping* replaced
    by mapping[s].  All other text is copied through unchanged.
    """
    def _rename(match):
        name = match.group(1)
        # Symbols without a replacement keep their original spelling.
        if name in mapping:
            return mapping[name]
        return match.group(0)

    # The pattern walks the expression left to right over non-overlapping
    # symbols; the callback result is inserted literally.
    return _symbol_pattern.sub(_rename, expr)
86
def _find_dependencies(symtab, exprs):
    """
    Returns a list of pair-wise dependencies from the parameter expressions.

    *symtab* maps symbol names to symbol table entries and *exprs* maps
    target names to expression strings.  Each returned pair is
    (target, source), where target is a key of *exprs* and source is the
    *symtab* entry of a symbol used in its expression.

    For example, with exprs = {'p3': 'p1+p2'} the result holds
    ('p3', symtab['p1']) and ('p3', symtab['p2']).  An expression without
    dependencies, such as {'p4': '2*pi'}, contributes ('p4', None).
    """
    deps = []
    for target, expr in exprs.items():
        for source in _symbols_or_none(expr, symtab):
            deps.append((target, source))
    return deps
99
100# Hack to deal with expressions without dependencies --- return a fake
101# dependency of None.
102# The better solution is fix order_dependencies so that it takes a
103# dictionary of {symbol: dependency_list}, for which no dependencies
104# is simply []; fix in parameter_mapping as well
def _symbols_or_none(expr,symtab):
    """
    Return the set of symbols used by *expr*, or the list [None] when the
    expression uses no symbol from *symtab*.
    """
    referenced = _symbols(expr,symtab)
    if not referenced:
        # Placeholder dependency so that the expression still yields a pair.
        return [None]
    return referenced
108
def _parameter_mapping(pairs):
    """
    Find the parameter substitution we need so that expressions can
    be evaluated without having to traverse a chain of
    model.layer.parameter.value

    *pairs* is a sequence of (target, source) dependencies; a source of
    None marks an expression with no dependencies and is skipped.  The
    parameters must be sortable.

    Returns (definition, substitution), where definition maps 'P<i>' to the
    i-th parameter in sorted order and substitution maps each parameter to
    the string 'P<i>.value'.  An empty *pairs* gives two empty dictionaries.
    """
    # Gather every distinct parameter from either side of a dependency.
    # Iterating the pairs directly (rather than unpacking zip(*pairs))
    # keeps an empty dependency list from raising.
    pars = sorted(set(p for pair in pairs for p in pair if p is not None))
    definition = dict(('P%d'%i, p) for i,p in enumerate(pars))
    substitution = dict((p, 'P%d.value'%i) for i,p in enumerate(pars))
    return definition, substitution
123
def no_constraints():
    """
    This parameter set has no constraints between the parameters.

    Returns 0, matching the generated constraints functions, so that it can
    be used directly in a fit problem definition.
    """
    return 0
129
def compile_constraints(symtab, exprs, context=None):
    """
    Build and return a function to evaluate all parameter expressions in
    the proper order.

    Input:

        *symtab* is the symbol table for the model: { 'name': parameter }

        *exprs* is the set of computed symbols: { 'name': 'expression' }

        *context* is any additional context needed to evaluate the
        expression: { 'name': object }.  The default is no extra context.

    Return:

        updater function which sets parameter.value for each expression
        and returns 0.  When *exprs* yields no dependencies,
        :func:`no_constraints` is returned instead.

    Raises:

       SyntaxError - improper expression syntax

       ValueError - expressions have circular dependencies

    A reference to a missing model, parameter or function is not detected
    here; it surfaces as an exception (typically NameError) when the
    returned function is called.

    This function is not terribly sophisticated, and it would be easy to
    trick.  However it handles the common cases cleanly and generates
    reasonable messages for the common errors.

    This code has not been fully audited for security.  While we have
    removed the builtins and the ability to import modules, there may
    be other vectors for users to perform more than simple function
    evaluations.  Unauthenticated users should not be running this code.

    Parameter names are assumed to contain only _.a-zA-Z0-9#[]

    Both names are provided for inverse functions, e.g., acos and arccos.

    Should try running the function to identify syntax errors before
    running it in a fit.

    Use help(fn) to see the code generated for the returned function fn.
    dis.dis(fn) will show the corresponding python vm instructions.
    """

    names = list(sorted(symtab.keys()))

    # Sort the parameters in the order they need to be evaluated.  The
    # dependencies are computed between names (using an identity symbol
    # table) rather than between names and parameter objects: the targets
    # in exprs are names, so a chain such as a=f(b), b=g(c) or a cycle is
    # only recognised when source 'b' compares equal to target 'b'.
    deps = _find_dependencies(dict((k, k) for k in names), exprs)
    if deps == []: return no_constraints
    # order_dependencies lists each target ahead of the sources it was
    # paired with, so walk its result backwards to assign sources first.
    order = order_dependencies(deps)
    order.reverse()

    # Rather than using the full path to the parameters in the parameter
    # expressions, instead use Pn, and substitute Pn.value for each occurrence
    # of the parameter in the expression.
    parameters = dict(('P%d'%i, symtab[k]) for i,k in enumerate(names))
    mapping = dict((k, 'P%d.value'%i) for i,k in enumerate(names))

    # Initialize dictionary with available functions
    namespace = {}
    namespace.update(math.__dict__)
    namespace.update(dict(arcsin=math.asin,arccos=math.acos,
                          arctan=math.atan,arctan2=math.atan2))
    if context is not None:
        namespace.update(context)
    namespace.update(parameters)
    namespace['id'] = id
    # Give the generated function an empty set of builtins.  This must be
    # done before exec: a function captures its builtins when it is created
    # (Python 3.10+), so removing the entry afterwards would leave
    # __import__ and friends reachable from the expressions.
    namespace['__builtins__'] = {}
    scope = {}

    # Define the constraints function
    assignments = ["=".join((p,exprs[p])) for p in order]
    code = [_substitute(s, mapping) for s in assignments]
    functiondef = """
def eval_expressions():
    '''
    %s
    '''
    %s
    return 0
"""%("\n    ".join(assignments),"\n    ".join(code))

    exec(functiondef, namespace, scope)
    retfn = scope['eval_expressions']

    # Remove module bookkeeping entries from the function's globals
    namespace.pop('__doc__',None)
    namespace.pop('__name__',None)
    namespace.pop('__file__',None)

    return retfn
222
def order_dependencies(pairs):
    """
    Order the left-hand elements of pairs so that a comes before b in the
    returned list for every pair (a,b) whose elements are both listed.

    *pairs* is a sequence of 2-element items; it may be a list, a tuple or
    an n x 2 numpy array.  Only elements that occur on the left of some
    pair appear in the result, each of them once.

    Raises ValueError if the pairs contain a cycle.
    """
    order = []

    # pairs may be an array as well as a list, so emptiness is tested with
    # len(); comparing an array against [] does not give a plain boolean.
    if len(pairs) > 0:
        left,right = [set(s) for s in zip(*pairs)]
    else:
        left,right = set(),set()
    while len(pairs) > 0:
        # Find which items only occur on the right
        independent = right - left
        if not independent:
            cycleset = ", ".join(str(s) for s in left)
            raise ValueError("Cyclic dependencies amongst %s"%cycleset)

        # The possibly resolvable items are those that depend on the independents
        dependent = set(a for a,b in pairs if b in independent)
        pairs = [(a,b) for a,b in pairs if b not in independent]
        if len(pairs) == 0:
            resolved = dependent
        else:
            left,right = [set(s) for s in zip(*pairs)]
            # Items that still appear on the left have unresolved
            # dependencies and must wait for a later pass.
            resolved = dependent - left
        order += resolved
    # Items were collected from the most independent outwards; the returned
    # list runs the other way.
    order.reverse()
    return order
255
256# ========= Test code ========
def _check(msg,pairs):
    """
    Verify that order_dependencies(pairs) contains exactly the left-hand
    items of pairs, each once, and that for every pair (lo,hi) with both
    elements in the result, lo is placed before hi.

    *msg* labels the test case in the ValueError raised on failure.
    *pairs* may be a list or an n x 2 numpy array.
    """
    # Iterate the pairs directly rather than comparing against []; that
    # comparison does not give a plain boolean when pairs is an array.
    items = set(lo for lo,_ in pairs)
    n = order_dependencies(pairs)
    if set(n) != items or len(n) != len(items):
        n.sort()
        items = sorted(items)
        raise ValueError("%s expect %s to contain %s for %s"%(msg,n,items,pairs))
    for lo,hi in pairs:
        if lo in n and hi in n and n.index(lo) >= n.index(hi):
            raise ValueError("%s expect %s before %s in %s for %s"%(msg,lo,hi,n,pairs))
272
def test_deps():
    """
    Exercise order_dependencies through _check on empty, small, cyclic,
    large random and deep-chain inputs, given as lists and numpy arrays.
    """
    import numpy as np

    # Null case
    _check("test empty",[])

    # Some dependencies, as lists and as an array
    sample = [(2,7),(1,5),(1,4),(2,1),(3,1),(5,6)]
    _check("test1",sample)
    _check("test1 renumbered",[(6,1),(7,3),(7,4),(6,7),(5,7),(3,2)])
    _check("test1 numpy",np.array(sample))

    # No dependencies
    _check("test2",[(4,1),(3,2),(8,4)])

    # Cycle test: ordering must be refused with a ValueError
    cyclic = [(1,4),(4,3),(4,5),(5,1)]
    detected = False
    try:
        order_dependencies(cyclic)
    except ValueError:
        detected = True
    if not detected:
        raise Exception("test3 expect ValueError exception for %s"%(cyclic,))

    # large test for gross speed check
    big = np.random.randint(4000,size=(1000,2))
    big[:,1] += 4000  # Avoid cycles
    _check("test-large",big)

    # depth tests: one long chain, in each direction
    depth = 200
    lower = range(0,depth)
    upper = range(1,depth+1)
    _check("depth-1",np.array([lower,upper]).T)
    _check("depth-2",np.array([upper,lower]).T)
308
def test_expr():
    """
    Self-test for symbol lookup, symbol substitution, dependency discovery
    and compile_constraints, including expressions that must be rejected.
    """
    import inspect
    import dis
    import math

    symtab = {'a.b.x':1, 'a.c':2, 'a.b':3, 'b.x':4}
    expr = 'a.b.x + sin(4*pi*a.c) + a.b.x/a.b'

    # Check symbol lookup
    assert _symbols(expr, symtab) == set([1,2,3])

    # Check symbol rename
    assert _substitute(expr,{'a.b.x':'Q'}) == 'Q + sin(4*pi*a.c) + Q/a.b'
    assert _substitute(expr,{'a.b':'Q'}) == 'a.b.x + sin(4*pi*a.c) + a.b.x/Q'


    # Check dependency builder
    # Fake parameter class
    class Parameter:
        def __init__(self, name, value=0, expression=''):
            self.path = name
            self.value = value
            self.expression = expression
        def iscomputed(self): return (self.expression != '')
        def __repr__(self): return self.path
    def world(*pars):
        # Build the (symtab, exprs) pair expected by the functions under test
        symtab = dict((p.path,p) for p in pars)
        exprs = dict((p.path,p.expression) for p in pars if p.iscomputed())
        return symtab, exprs
    p1 = Parameter('G0.sigma',5)
    p2 = Parameter('other',expression='2*pi*sin(G0.sigma/.1875) + M1.G1')
    p3 = Parameter('M1.G1',6)
    p4 = Parameter('constant',expression='2*pi*35')
    # Simple chain
    assert set(_find_dependencies(*world(p1,p2,p3))) == set([(p2.path,p1),(p2.path,p3)])
    # Constant expression
    assert set(_find_dependencies(*world(p1,p4))) == set([(p4.path,None)])
    # No dependencies
    assert set(_find_dependencies(*world(p1,p3))) == set([])

    # Check function builder
    fn = compile_constraints(*world(p1,p2,p3))

    # Inspect the resulting function
    if 0:
        print(inspect.getdoc(fn))
        print(dis.dis(fn))

    # Evaluate the function and see if it updates the
    # target value as expected
    fn()
    expected = 2*math.pi*math.sin(5/.1875) + 6
    assert p2.value == expected,"Value was %s, not %s"%(p2.value,expected)

    # Check empty dependency set doesn't crash
    fn = compile_constraints(*world(p1,p3))
    fn()

    # Check that constants are evaluated properly
    fn = compile_constraints(*world(p4))
    fn()
    assert p4.value == 2*math.pi*35

    # Check additional context example; this also tests multiple
    # expressions
    class Table:
        Si = 2.09
        values = {'Si': 2.07}
    tbl = Table()
    p5 = Parameter('lookup',expression="tbl.Si")
    fn = compile_constraints(*world(p1,p2,p3,p5),context=dict(tbl=tbl))
    fn()
    assert p5.value == 2.09,"Value for %s was %s"%(p5.expression,p5.value)
    p5.expression = "tbl.values['Si']"
    fn = compile_constraints(*world(p1,p2,p3,p5),context=dict(tbl=tbl))
    fn()
    assert p5.value == 2.07,"Value for %s was %s"%(p5.expression,p5.value)


    # Verify that we capture invalid expressions
    for expr in ['G4.cage', 'M0.cage', 'M1.G1 + *2',
                 'piddle',
                 '5; import sys; print "p0wned"',
                 '__import__("sys").argv']:
        try:
            p6 = Parameter('broken',expression=expr)
            fn = compile_constraints(*world(p6))
            fn()
        except Exception:
            # Any failure, while compiling or while evaluating, is the
            # expected outcome for these expressions.
            pass
        else:
            # Raise a real exception: raising a bare string is itself a
            # TypeError and would hide this message.
            raise Exception("Failed to raise error for %s"%expr)
401
if __name__ == "__main__":
    # Run the module self-tests when executed as a script.
    for _selftest in (test_expr, test_deps):
        _selftest()
Note: See TracBrowser for help on using the repository browser.