source: sasview/src/sas/sascalc/fit/expression.py @ 45d90b9

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 45d90b9 was b699768, checked in by Piotr Rozyczko <piotr.rozyczko@…>, 9 years ago

Initial commit of the refactored SasCalc module.

  • Property mode set to 100644
File size: 13.6 KB
RevLine 
[4e9f227]1# This program is public domain
2"""
3Parameter expression evaluator.
4
5For systems in which constraints are expressed as string expressions rather
6than python code, :func:`compile_constraints` can construct an expression
7evaluator that substitutes the computed values of the expressions into the
8parameters.
9
10The compiler requires a symbol table, an expression set and a context.
11The symbol table maps strings containing fully qualified names such as
12'M1.c[3].full_width' to parameter objects with a 'value' property that
13can be queried and set.  The expression set maps symbol names from the
14symbol table to string expressions.  The context provides additional symbols
15for the expressions in addition to the usual mathematical functions and
16constants.
17
18The expressions are compiled and interpreted by python, with only minimal
19effort to make sure that they don't contain bad code.  The resulting
20constraints function returns 0 so it can be used directly in a fit problem
21definition.
22
23Extracting the symbol table from the model depends on the structure of the
24model.  If fitness.parameters() is set correctly, then this should simply
25be a matter of walking the parameter data, remembering the path to each
26parameter in the symbol table.  For compactness, dictionary elements should
27be referenced by .name rather than ["name"].  Model name can be used as the
28top level.
29
30Getting the parameter expressions applied correctly is challenging.
31The following monkey patch works by overriding model_update in FitProblem
32so that after setp(p) is called, the constraints expression can be
33applied before telling the underlying fitness function that the model
34is out of date::
35
36        # Override model update so that parameter constraints are applied
37        problem._model_update = problem.model_update
38        def model_update():
39            constraints()
40            problem._model_update()
41        problem.model_update = model_update
42
43Ideally, this interface will change
44"""
45import math
46import re
47
# Identifier pattern: a letter or underscore followed by any run of
# letters, digits, underscores and dots, so a dotted path such as
# M1.G1 is captured as one symbol.  No attempt is made to reject
# malformed names such as a3...9; for syntactically correct input every
# match is a genuine symbol.
_symbol_pattern = re.compile(
    '([a-zA-Z_][a-zA-Z_0-9.]*)'
)
52
def _symbols(expr, symtab):
    """
    Collect the symbol-table entries referenced by an expression.

    Every identifier found in *expr* that is a key of *symtab* contributes
    its *symtab* value to the result; identifiers that are not in *symtab*
    (function names, constants, ...) are skipped.  The result is an
    unordered set, so each entry appears once however often its name
    occurs in the expression.

    This is the first step in computing a dependency graph.
    """
    found = set()
    for match in _symbol_pattern.finditer(expr):
        name = match.group(0)
        if name in symtab:
            found.add(symtab[name])
    return found
64
def _substitute(expr, mapping):
    """
    Return *expr* with every symbol that is a key of *mapping* replaced by
    the corresponding mapping value.  Symbols that are not in *mapping*,
    and all text between symbols, are copied through unchanged.
    """
    def _swap(match):
        # Hand back the replacement text for known symbols, otherwise the
        # text that was matched so the expression is left as it was.
        name = match.group(1)
        return mapping[name] if name in mapping else match.group(0)

    return _symbol_pattern.sub(_swap, expr)
84
def _find_dependencies(symtab, exprs):
    """
    Build the list of pair-wise dependencies for the parameter expressions.

    *exprs* maps each target name to its expression string.  One
    (target, source) pair is produced for every symbol-table entry the
    expression refers to, where source is the value stored in *symtab*
    for that symbol.  For example, an expression for p3 that uses the
    symbols of p1 and p2 contributes (p3, p1) and (p3, p2).  An expression
    that uses nothing from *symtab*, such as 2*pi for p4, contributes the
    single pair (p4, None).
    """
    deps = []
    for target, expr in exprs.items():
        for source in _symbols_or_none(expr, symtab):
            deps.append((target, source))
    return deps
97
# Workaround for expressions without dependencies: report a single fake
# dependency of None so the expression still produces a pair.
# A cleaner solution would be for order_dependencies to take a dictionary
# of {symbol: dependency_list}, in which "no dependencies" is simply [];
# _parameter_mapping would need the same change.
def _symbols_or_none(expr, symtab):
    """
    Return the set of *symtab* entries used by *expr*, or the one-element
    list [None] when the expression uses none of them.
    """
    found = _symbols(expr, symtab)
    if found:
        return found
    return [None]
106
def _parameter_mapping(pairs):
    """
    Find the parameter substitution we need so that expressions can
    be evaluated without having to traverse a chain of
    model.layer.parameter.value

    *pairs* is an iterable of (target, source) dependency pairs; a source
    of None marks an expression with no dependencies and is ignored.

    Returns (definition, substitution): *definition* maps the short names
    'P0', 'P1', ... to the distinct items of *pairs* in sorted order, and
    *substitution* maps each of those items to the matching 'Pn.value'
    string.  Both are empty when *pairs* is empty.
    """
    # Materialise first so that any iterable is accepted and can be tested
    # for emptiness; zip(*[]) yields nothing to unpack into left/right.
    pairs = list(pairs)
    if not pairs:
        return {}, {}

    left, right = zip(*pairs)
    # None only appears as the placeholder source of a dependency-free
    # expression, so it never names a parameter.
    pars = sorted(p for p in set(left + right) if p is not None)
    definition = dict(('P%d' % i, p) for i, p in enumerate(pars))
    substitution = dict((p, 'P%d.value' % i) for i, p in enumerate(pars))
    return definition, substitution
121
def no_constraints():
    """
    Do-nothing constraints function, used when the parameter set has no
    constraints between the parameters.  Returns None.
    """
    return None
127
def compile_constraints(symtab, exprs, context=None):
    """
    Build and return a function to evaluate all parameter expressions in
    the proper order.

    Input:

        *symtab* is the symbol table for the model: { 'name': parameter }

        *exprs* is the set of computed symbols: { 'name': 'expression' }

        *context* is any additional context needed to evaluate the
        expression: { 'name': object }.  None (the default) adds nothing.

    Return:

        updater function which sets parameter.value for each expression
        and returns 0, or :func:`no_constraints` when *exprs* produces no
        dependencies at all.

    Raises:

       SyntaxError - improper expression syntax

       ValueError - expressions have circular dependencies

    A name that is missing from the symbol table, the math functions and
    *context* is not detected here; it raises NameError when the returned
    function is called.

    This function is not terribly sophisticated, and it would be easy to
    trick.  However it handles the common cases cleanly and generates
    reasonable messages for the common errors.

    This code has not been fully audited for security.  The generated
    function runs with an empty set of builtins, which also removes the
    ability to import modules, but there may be other vectors for users
    to perform more than simple function evaluations.  Unauthenticated
    users should not be running this code.

    Symbols in expressions are recognised by the module-level
    _symbol_pattern (letters, digits, underscore and dot).

    Both names are provided for inverse functions, e.g., acos and arccos.

    Should try running the function to identify syntax errors before
    running it in a fit.

    Use help(fn) to see the code generated for the returned function fn.
    dis.dis(fn) will show the corresponding python vm instructions.
    """

    # Sort the parameters in the order they need to be evaluated
    # NOTE(review): _find_dependencies pairs expression names with symtab
    # *values*, so an expression that uses another computed parameter is
    # only ordered relative to it if those values compare equal to the
    # names -- confirm chained expressions are evaluated in the intended
    # order.
    deps = _find_dependencies(symtab, exprs)
    if not deps:
        return no_constraints
    order = order_dependencies(deps)

    # Rather than using the full path to the parameters in the parameter
    # expressions, instead use Pn, and substitute Pn.value for each occurrence
    # of the parameter in the expression.
    names = sorted(symtab.keys())
    parameters = dict(('P%d' % i, symtab[k]) for i, k in enumerate(names))
    mapping = dict((k, 'P%d.value' % i) for i, k in enumerate(names))

    # Namespace the generated function runs in: math functions, the
    # arc* aliases, caller context, then the Pn parameter objects.
    namespace = dict(math.__dict__)
    namespace.update(arcsin=math.asin, arccos=math.acos,
                     arctan=math.atan, arctan2=math.atan2)
    if context is not None:
        namespace.update(context)
    namespace.update(parameters)
    namespace['id'] = id
    # Install empty builtins *before* the function is created.  Removing
    # them afterwards is not enough on interpreters where a function
    # captures its builtins when it is defined.
    namespace['__builtins__'] = {}
    scope = {}

    # Define the constraints function
    assignments = ["=".join((p, exprs[p])) for p in order]
    code = [_substitute(s, mapping) for s in assignments]
    functiondef = """
def eval_expressions():
    '''
    %s
    '''
    %s
    return 0
"""%("\n    ".join(assignments),"\n    ".join(code))

    # The call form of exec is accepted by both Python 2 and Python 3.
    exec(functiondef, namespace, scope)
    retfn = scope['eval_expressions']

    # Remove module attributes copied in from math.__dict__
    for key in ('__doc__', '__name__', '__file__'):
        namespace.pop(key, None)

    return retfn
221
def order_dependencies(pairs):
    """
    Order the left-hand elements of *pairs* by dependency.

    *pairs* is any iterable of two-item entries (a, b).  The result is a
    list holding each distinct left-hand element a exactly once (elements
    that only ever appear on the right are not included), arranged so
    that a comes before b for every pair (a, b) whose members are both in
    the list.  An empty input gives an empty list.

    NOTE(review): when a pair is read as (target, source) this places a
    target ahead of the targets it depends on; confirm that callers which
    use the list as an evaluation order expect that.

    Raises ValueError if the pairs contain a cycle.
    """
    # Work on a plain list of tuples so that lists, tuples, generators and
    # numpy arrays of pairs are all handled the same way.
    pairs = [(a, b) for a, b in pairs]
    order = []

    while pairs:
        left = set(a for a, _ in pairs)
        right = set(b for _, b in pairs)

        # Items which only occur on the right have nothing left to wait for
        independent = right - left
        if not independent:
            cycleset = ", ".join(str(s) for s in left)
            raise ValueError("Cyclic dependencies amongst %s" % cycleset)

        # The possibly resolvable items are those that depend on the
        # independents; drop the pairs that have just been satisfied.
        dependent = set(a for a, b in pairs if b in independent)
        pairs = [(a, b) for a, b in pairs if b not in independent]

        # Anything still on the left of a remaining pair has to wait for
        # a later pass.
        waiting = set(a for a, _ in pairs)
        order += dependent - waiting

    order.reverse()
    return order
254
255# ========= Test code ========
def _check(msg, pairs):
    """
    Run order_dependencies on *pairs* and verify the result.

    The returned list must contain exactly the left-hand items of *pairs*,
    each once, and for every pair (lo, hi) whose members are both in the
    list, lo must appear before hi.

    *msg* labels the test in the error text.  Raises Exception describing
    the first violation found.
    """
    # Plain list of tuples, so numpy arrays and lists are treated alike
    pairs = [(lo, hi) for lo, hi in pairs]
    items = set(lo for lo, _ in pairs)
    n = order_dependencies(pairs)
    if set(n) != items or len(n) != len(items):
        raise Exception("%s expect %s to contain %s for %s"
                        % (msg, sorted(n), sorted(items), pairs))

    # n has no duplicates at this point, so each item has one position
    position = dict((item, i) for i, item in enumerate(n))
    for lo, hi in pairs:
        if lo in position and hi in position and position[lo] >= position[hi]:
            raise Exception("%s expect %s before %s in %s for %s"
                            % (msg, lo, hi, n, pairs))
271
def test_deps():
    """
    Self-test for order_dependencies: empty input, small hand-written
    graphs (as lists and as a numpy array), cycle detection, a large
    random graph and two long chains.  Raises on the first failure.
    """
    import numpy

    # Null case
    _check("test empty", [])

    # Some dependencies
    _check("test1", [(2,7),(1,5),(1,4),(2,1),(3,1),(5,6)])
    _check("test1 renumbered", [(6,1),(7,3),(7,4),(6,7),(5,7),(3,2)])
    _check("test1 numpy", numpy.array([(2,7),(1,5),(1,4),(2,1),(3,1),(5,6)]))

    # Short independent chains
    _check("test2", [(4,1),(3,2),(8,4)])

    # Cycle test
    pairs = [(1,4),(4,3),(4,5),(5,1)]
    try:
        order_dependencies(pairs)
    except ValueError:
        pass
    else:
        raise Exception("test3 expect ValueError exception for %s"%(pairs,))

    # large test for gross speed check
    A = numpy.random.randint(4000, size=(1000,2))
    A[:,1] += 4000  # Avoid cycles
    _check("test-large", A)

    # depth tests: a chain of k links, in each direction
    k = 200
    A = numpy.array([range(0,k), range(1,k+1)]).T
    _check("depth-1", A)

    A = numpy.array([range(1,k+1), range(0,k)]).T
    _check("depth-2", A)
304
def test_expr():
    """
    Self-test for symbol lookup, symbol substitution, dependency
    discovery and compile_constraints.  Raises on the first failed check.
    """
    import math

    symtab = {'a.b.x':1, 'a.c':2, 'a.b':3, 'b.x':4}
    expr = 'a.b.x + sin(4*pi*a.c) + a.b.x/a.b'

    # Check symbol lookup
    assert _symbols(expr, symtab) == set([1,2,3])

    # Check symbol rename
    assert _substitute(expr,{'a.b.x':'Q'}) == 'Q + sin(4*pi*a.c) + Q/a.b'
    assert _substitute(expr,{'a.b':'Q'}) == 'a.b.x + sin(4*pi*a.c) + a.b.x/Q'

    # Check dependency builder
    # Fake parameter class
    class Parameter:
        def __init__(self, name, value=0, expression=''):
            self.path = name
            self.value = value
            self.expression = expression
        def iscomputed(self): return (self.expression != '')
        def __repr__(self): return self.path
    def world(*pars):
        # Build the (symtab, exprs) arguments from a set of parameters
        symtab = dict((p.path,p) for p in pars)
        exprs = dict((p.path,p.expression) for p in pars if p.iscomputed())
        return symtab, exprs
    p1 = Parameter('G0.sigma',5)
    p2 = Parameter('other',expression='2*pi*sin(G0.sigma/.1875) + M1.G1')
    p3 = Parameter('M1.G1',6)
    p4 = Parameter('constant',expression='2*pi*35')
    # Simple chain
    assert set(_find_dependencies(*world(p1,p2,p3))) == set([(p2.path,p1),(p2.path,p3)])
    # Constant expression
    assert set(_find_dependencies(*world(p1,p4))) == set([(p4.path,None)])
    # No dependencies
    assert set(_find_dependencies(*world(p1,p3))) == set([])

    # Check function builder
    fn = compile_constraints(*world(p1,p2,p3))

    # When debugging, inspect.getdoc(fn) shows the generated assignments
    # and dis.dis(fn) the corresponding bytecode.

    # Evaluate the function and see if it updates the
    # target value as expected
    fn()
    expected = 2*math.pi*math.sin(5/.1875) + 6
    assert p2.value == expected,"Value was %s, not %s"%(p2.value,expected)

    # Check empty dependency set doesn't crash
    fn = compile_constraints(*world(p1,p3))
    fn()

    # Check that constants are evaluated properly
    fn = compile_constraints(*world(p4))
    fn()
    assert p4.value == 2*math.pi*35

    # Check additional context example; this also tests multiple
    # expressions
    class Table:
        Si = 2.09
        values = {'Si': 2.07}
    tbl = Table()
    p5 = Parameter('lookup',expression="tbl.Si")
    fn = compile_constraints(*world(p1,p2,p3,p5),context=dict(tbl=tbl))
    fn()
    assert p5.value == 2.09,"Value for %s was %s"%(p5.expression,p5.value)
    p5.expression = "tbl.values['Si']"
    fn = compile_constraints(*world(p1,p2,p3,p5),context=dict(tbl=tbl))
    fn()
    assert p5.value == 2.07,"Value for %s was %s"%(p5.expression,p5.value)

    # Verify that we capture invalid expressions
    for expr in ['G4.cage', 'M0.cage', 'M1.G1 + *2',
                 'piddle',
                 '5; import sys; print "p0wned"',
                 '__import__("sys").argv']:
        try:
            p6 = Parameter('broken',expression=expr)
            fn = compile_constraints(*world(p6))
            fn()
        except Exception:
            # Any failure, at compile time or at call time, is the
            # expected outcome for these expressions.
            pass
        else:
            # A bare string cannot be raised; wrap the message in an
            # exception so the failing expression is actually reported.
            raise Exception("Failed to raise error for %s"%expr)
397
# Run the self-tests when the module is executed as a script:
# expression handling first, then dependency ordering.
if __name__ == "__main__":
    for _selftest in (test_expr, test_deps):
        _selftest()
Note: See TracBrowser for help on using the repository browser.