source: sasview/park-1.2.1/park/expression.py @ b4293d2

ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Last change on this file since b4293d2 was 3570545, checked in by Mathieu Doucet <doucetm@…>, 13 years ago

Adding park Part 2

  • Property mode set to 100644
File size: 8.3 KB
Line 
1# This program is public domain
2"""
3Functions for manipulating expressions.   
4"""
5import math
6import re
7from deps import order_dependencies
8
9# simple pattern which matches symbols.  Note that it will also match
10# invalid substrings such as a3...9, but given syntactically correct
11# input it will only match symbols.
12_symbol_pattern = re.compile('([a-zA-Z][a-zA-Z_0-9.]*)')
13
def symbols(expr,symtab):
    """
    Return the symbol-table entries referenced by an expression.

    Every name in *expr* matched by the module's symbol pattern that is also
    a key of *symtab* is looked up, and the corresponding *values* are
    collected into a set.  Each entry therefore appears once no matter how
    often its name occurs.  Names absent from *symtab* (functions such as
    sin, constants such as pi) are ignored.  The set has no particular order.

    This is the first step in computing a dependency graph.
    """
    found = set()
    for match in _symbol_pattern.finditer(expr):
        name = match.group(0)
        if name in symtab:
            found.add(symtab[name])
    return found
25
def substitute(expr,mapping):
    """
    Return *expr* with each symbol s that is a key of *mapping* replaced by
    the text mapping[s].

    Symbols are located with the module's symbol pattern, so only whole
    matches are replaced: with mapping {'a.b': 'Q'} the symbol 'a.b.x' is
    left untouched.  All text between symbols is copied unchanged.
    """
    def _replace(match):
        # Swap in the mapped text for known symbols; keep everything else.
        name = match.group(1)
        if name in mapping:
            return mapping[name]
        return match.group(0)

    return _symbol_pattern.sub(_replace, expr)
45
def find_dependencies(pars):
    """
    Return a list of (parameter, dependency) pairs from the parameter
    expressions.

    Only parameters whose iscomputed() is true contribute.  Each one yields
    a pair for every parameter (looked up by path among *pars*) that its
    expression mentions.  For example, if p3 = p1+p2 then
    find_dependencies([p1,p2,p3]) returns [(p3,p1),(p3,p2)] in some order.
    A computed parameter that mentions no other parameter, such as
    p4 = 2*pi, yields the single pair (p4, None).
    """
    symtab = dict((p.path, p) for p in pars)
    deps = []
    for par in pars:
        if not par.iscomputed():
            continue
        used = symbols(par.expression, symtab)
        # HACK: an expression without dependencies still has to be
        # evaluated, so record a fake dependency on None.  The better fix is
        # for order_dependencies to accept {symbol: dependency_list}, where
        # "no dependencies" is simply []; parameter_mapping would need the
        # matching change.
        if not used:
            used = [None]
        for dep in used:
            deps.append((par, dep))
    return deps
67
def parameter_mapping(pairs):
    """
    Find the parameter substitution we need so that expressions can be
    evaluated without having to traverse a chain of
    model.layer.parameter.value.

    *pairs* is a dependency list as produced by find_dependencies: a
    sequence of (parameter, dependency) tuples in which the dependency may
    be None for an expression with no parameter references.  Each distinct
    parameter is given a short alias P0, P1, ...

    Returns (symtab, mapping):
        symtab maps each alias to its parameter object;
        mapping maps each parameter path to the text 'Pn.value' that should
        replace that path in an expression.
    The None placeholder never receives an alias, and an empty *pairs*
    gives two empty dictionaries.
    """
    pars = set()
    for left, right in pairs:
        pars.add(left)
        pars.add(right)
    # None only marks "expression with no dependencies"; it is not a
    # parameter, so it must not be bound to an alias in the namespace.
    pars.discard(None)

    symtab = {}
    mapping = {}
    # Fill both tables in one pass so alias Pn always denotes the same
    # parameter in symtab and in mapping.
    for i, p in enumerate(pars):
        alias = 'P%d' % i
        symtab[alias] = p
        mapping[p.path] = alias + '.value'
    return symtab, mapping
82
def no_constraints():
    """
    Do-nothing evaluator, handed back by build_eval when none of the
    parameters is computed from an expression.  Calling it changes nothing.
    """
    return None
88
def build_eval(pars, context=None):
    """
    Build and return a function to evaluate all parameter expressions in
    the proper order.

    Inputs:
        pars is a list of parameters
        context is an optional dictionary of additional symbols for the
        expressions (None, the default, means no extra symbols)

    Output:
        updater function taking no arguments; calling it assigns the value
        of every computed parameter in dependency order.  If no parameter
        is computed, the do-nothing no_constraints function is returned.

    Raises:
        SyntaxError - improper expression syntax (raised while building)
        ValueError - expressions have circular dependencies (reported by
        order_dependencies)
        NameError - raised when the returned function is *called* if an
        expression refers to a model, parameter or function that does not
        exist

    This function is not terribly sophisticated, and it would be easy to
    trick.  However it handles the common cases cleanly and generates
    reasonable messages for the common errors.

    This code has not been fully audited for security.  While we have
    removed the builtins and the ability to import modules, there may
    be other vectors for users to perform more than simple function
    evaluations.  Unauthenticated users should not be running this code.

    Parameter names are assumed to contain only _.a-zA-Z0-9#[]

    The list of parameters is probably something like::

        parset.setprefix()
        pars = parset.flatten()

    Note that math uses acos while numpy uses arccos.  To avoid confusion
    we allow both.

    Should try running the function to identify syntax errors before
    running it in a fit.

    Use help(fn) to see the code generated for the returned function fn.
    dis.dis(fn) will show the corresponding python vm instructions.
    """
    # Namespace the generated function runs in: the public math names
    # (module dunders such as __name__/__doc__ are skipped so they do not
    # leak into the namespace), numpy-style aliases, then caller symbols.
    namespace = {}
    for name, value in math.__dict__.items():
        if not name.startswith('__'):
            namespace[name] = value
    namespace.update(dict(arcsin=math.asin, arccos=math.acos,
                          arctan=math.atan, arctan2=math.atan2))
    if context:
        namespace.update(context)

    # Sort the parameters in the order they need to be evaluated
    deps = find_dependencies(pars)
    if not deps:
        return no_constraints
    par_table, par_mapping = parameter_mapping(deps)
    order = order_dependencies(deps)

    # Parameter aliases (P0, P1, ...) take precedence over everything else.
    namespace.update(par_table)
    # Give the generated code an explicitly empty set of builtins.  exec()
    # would otherwise insert the real builtins module, and since Python 3.10
    # a function captures its builtins when it is created, so deleting the
    # entry after exec would no longer hide __import__, open, etc.
    namespace['__builtins__'] = {}

    # Define the function body: one "path=expression" assignment per
    # computed parameter, with paths rewritten to the short aliases.
    exprs = [p.path + "=" + p.expression for p in order]
    code = [substitute(s, par_mapping) for s in exprs]

    # The untranslated assignments become the docstring so help(fn)
    # shows what the function evaluates.
    functiondef = """
def eval_expressions():
    '''
    %s
    '''
    %s
""" % ("\n    ".join(exprs), "\n    ".join(code))

    # The function-call form of exec is accepted by Python 2.7 and 3.
    local_ns = {}
    exec(functiondef, namespace, local_ns)
    return local_ns['eval_expressions']
174
def test():
    """
    Self-test for symbols, substitute, find_dependencies and build_eval.

    Every check is an assert, so the first failing check raises
    AssertionError; nothing is returned.
    """
    import inspect
    import dis

    symtab = {'a.b.x':1, 'a.c':2, 'a.b':3, 'b.x':4}
    expr = 'a.b.x + sin(4*pi*a.c) + a.b.x/a.b'

    # Check symbol lookup
    assert symbols(expr, symtab) == set([1,2,3])

    # Check symbol rename
    assert substitute(expr,{'a.b.x':'Q'}) == 'Q + sin(4*pi*a.c) + Q/a.b'

    # Check dependency builder
    # Fake parameter class
    class Parameter:
        def __init__(self, name, value=0, expression=''):
            self.path = name
            self.value = value
            self.expression = expression
        def iscomputed(self): return (self.expression != '')
        def __repr__(self): return self.path
    p1 = Parameter('G0.sigma',5)
    p2 = Parameter('other',expression='2*pi*sin(G0.sigma/.1875) + M1.G1')
    p3 = Parameter('M1.G1',6)
    p4 = Parameter('constant',expression='2*pi*35')
    # Simple chain
    assert set(find_dependencies([p1,p2,p3])) == set([(p2,p1),(p2,p3)])
    # Constant expression
    assert set(find_dependencies([p1,p4])) == set([(p4,None)])
    # No dependencies
    assert set(find_dependencies([p1,p3])) == set([])

    # Check function builder
    fn = build_eval([p1,p2,p3])

    # Inspect the resulting function (flip to True when debugging)
    if False:
        print(inspect.getdoc(fn))
        dis.dis(fn)  # prints the listing itself; its return value is None

    # Evaluate the function and see if it updates the
    # target value as expected
    fn()
    expected = 2*math.pi*math.sin(5/.1875) + 6
    assert p2.value == expected,"Value was %s, not %s"%(p2.value,expected)

    # Check empty dependency set doesn't crash
    fn = build_eval([p1,p3])
    fn()

    # Check that constants are evaluated properly
    fn = build_eval([p4])
    fn()
    assert p4.value == 2*math.pi*35

    # Check additional context example; this also tests multiple
    # expressions
    class Table:
        Si = 2.09
        values = {'Si': 2.07}
    tbl = Table()
    p5 = Parameter('lookup',expression="tbl.Si")
    fn = build_eval([p1,p2,p3,p5],context=dict(tbl=tbl))
    fn()
    assert p5.value == 2.09,"Value for %s was %s"%(p5.expression,p5.value)
    p5.expression = "tbl.values['Si']"
    fn = build_eval([p1,p2,p3,p5],context=dict(tbl=tbl))
    fn()
    assert p5.value == 2.07,"Value for %s was %s"%(p5.expression,p5.value)

    # Verify that we capture invalid expressions: each one must fail either
    # while the evaluator is built or when it is run.
    for expr in ['G4.cage', 'M0.cage', 'M1.G1 + *2',
                 'piddle',
                 'import sys; print "p0wned"',
                 '__import__("sys").argv']:
        p6 = Parameter('broken',expression=expr)
        try:
            fn = build_eval([p6])
            fn()
        except Exception:
            # Any exception (SyntaxError, NameError, ...) counts as the
            # expression being rejected.
            pass
        else:
            raise AssertionError("Failed to raise error for %s"%expr)
249    for expr in ['G4.cage', 'M0.cage', 'M1.G1 + *2', 
250                 'piddle',
251                 'import sys; print "p0wned"',
252                 '__import__("sys").argv']:
253        try:
254            p6 = Parameter('broken',expression=expr)
255            fn = build_eval([p6])
256            fn()
257        except Exception,msg: pass
258        else:  raise "Failed to raise error for %s"%expr
259
260if __name__ == "__main__": test()
Note: See TracBrowser for help on using the repository browser.