source: sasview/park-1.2.1/park/expression.py @ 59b1b92

ESS_GUI, ESS_GUI_Docs, ESS_GUI_batch_fitting, ESS_GUI_bumps_abstraction, ESS_GUI_iss1116, ESS_GUI_iss879, ESS_GUI_iss959, ESS_GUI_opencl, ESS_GUI_ordering, ESS_GUI_sync_sascalc, costrafo411, magnetic_scatt, release-4.1.1, release-4.1.2, release-4.2.2, release_4.0.1, ticket-1009, ticket-1094-headless, ticket-1242-2d-resolution, ticket-1243, ticket-1249, ticket885, unittest-saveload
Last change on this file since 59b1b92 was e3efa6b3, checked in by pkienzle, 11 years ago

restructure bumps wrapper and add levenberg-marquardt

  • Property mode set to 100644
File size: 8.4 KB
Line 
1# This program is public domain
2"""
3Functions for manipulating expressions.   
4"""
5import math
6import re
7from deps import order_dependencies
8
# Regular expression recognising a symbol: a letter followed by any run of
# letters, digits, underscores and dots, so a dotted path such as 'M1.G1'
# or 'a.b.x' is matched as one symbol.  The pattern is deliberately loose:
# it also matches text that is not a valid symbol, such as 'a3...9', the
# 'e5' inside the number '1e5', or words inside string literals.  Callers
# only act on matches that are keys of their symbol table, so such false
# matches matter only when they coincide with a known symbol name.
_symbol_pattern = re.compile(r'([a-zA-Z][a-zA-Z_0-9.]*)')
13
def symbols(expr, symtab):
    """
    Return the set of symbol-table entries referenced by an expression.

    Every symbol-like token in expr (as matched by _symbol_pattern) that is
    a key of symtab contributes symtab[token] to the result; tokens that
    are not keys (function names such as sin, constants such as pi, ...)
    are ignored.  A symbol used several times appears once, because the
    result is a set, and its iteration order is arbitrary.

    This is the first step in computing a dependency graph.
    """
    found = set()
    for match in _symbol_pattern.finditer(expr):
        name = match.group(0)
        if name in symtab:
            found.add(symtab[name])
    return found
25
def substitute(expr, mapping):
    """
    Return expr with every symbol s that is a key of mapping replaced by
    mapping[s].

    Matching is on whole symbols as recognised by _symbol_pattern.  With
    mapping {'a.b': 'Q'}, for instance, the longer symbol 'a.b.x' is left
    untouched while a standalone 'a.b' becomes 'Q'.  Text between symbols
    is copied through unchanged.
    """
    result = []
    cursor = 0
    for match in _symbol_pattern.finditer(expr):
        name = match.group(1)
        if name not in mapping:
            continue
        # Copy the untouched text before this symbol, then its replacement.
        result.append(expr[cursor:match.start()])
        result.append(mapping[name])
        cursor = match.end()
    # Whatever follows the last replaced symbol.
    result.append(expr[cursor:])
    return "".join(result)
45
def find_dependencies(pars):
    """
    Return the pair-wise dependencies implied by the parameter expressions.

    Only computed parameters (p.iscomputed() is true) contribute.  For each
    of them, a pair (p, dep) is produced for every parameter dep in pars
    whose path appears in p.expression.  For example, if p3 = p1+p2 then
    find_dependencies([p1,p2,p3]) returns [(p3,p1),(p3,p2)], with the order
    within one parameter's pairs arbitrary.  An expression that references
    no parameter in pars, such as p4 = 2*pi, yields the placeholder pair
    (p4, None).
    """
    # Index the parameters by path so symbols found in expressions can be
    # resolved back to parameter objects.
    symtab = dict((p.path, p) for p in pars)
    deps = []
    for p in pars:
        if not p.iscomputed():
            continue
        used = symbols(p.expression, symtab)
        if used:
            deps.extend((p, dep) for dep in used)
        else:
            # Hack to deal with expressions without dependencies: record a
            # fake dependency of None so the parameter is still listed.
            # The better solution is to fix order_dependencies so that it
            # takes a dictionary of {symbol: dependency_list}, for which no
            # dependencies is simply []; fix in parameter_mapping as well.
            deps.append((p, None))
    return deps
67
def parameter_mapping(pairs):
    """
    Find the parameter substitution we need so that expressions can
    be evaluated without having to traverse a chain of
    model.layer.parameter.value.

    Every distinct object appearing on either side of a dependency pair
    is given a short alias 'P<i>'.

    Inputs:
        pairs is an iterable of (parameter, dependency) pairs, as produced
        by find_dependencies; the dependency is None for an expression
        with no parameter dependencies.

    Returns (symtab, mapping):
        symtab maps each alias 'P<i>' to its object (None included when
        it occurs in pairs).
        mapping maps each real parameter's path to the text 'P<i>.value',
        suitable for passing to substitute().

    An empty pairs yields ({}, {}).
    """
    # Collect the distinct objects from both sides of every pair.
    pars = set()
    for left, right in pairs:
        pars.add(left)
        pars.add(right)

    # Build both tables in a single pass so an object is guaranteed to get
    # the same alias in symtab and in mapping.
    symtab = {}
    mapping = {}
    for i, p in enumerate(pars):
        alias = 'P%d' % i
        symtab[alias] = p
        # p is None when there is an expression with no dependencies; it
        # has no path, so it gets an alias but no substitution.
        if p is not None:
            mapping[p.path] = alias + '.value'
    return symtab, mapping
82
def no_constraints():
    """
    Updater used for a parameter set with no constraints between its
    parameters: there is nothing to recompute, so it does nothing and
    returns None.
    """
    return None
88
def build_eval(pars, context=None):
    """
    Build and return a function to evaluate all parameter expressions in
    the proper order.

    Inputs:
        pars is a list of parameters
        context is an optional dictionary of additional symbols for the
        expressions; None (the default) means no additional symbols

    Output:
        updater function.  It takes no arguments and assigns the value of
        every computed parameter from its expression.  When no parameter
        is computed, the no-op function no_constraints is returned.

    Raises:
       SyntaxError - improper expression syntax (raised while building)
       ValueError - expressions have circular dependencies
       NameError - an expression uses a name that is not a parameter, a
           math function or a context symbol (raised when the returned
           function is called)

    This function is not terribly sophisticated, and it would be easy to
    trick.  However it handles the common cases cleanly and generates
    reasonable messages for the common errors.

    This code has not been fully audited for security.  The generated
    function runs with an empty builtins table, so names such as
    __import__ and open are unavailable and import statements fail, but
    there may be other vectors for users to perform more than simple
    function evaluations.  Unauthenticated users should not be running
    this code.

    Parameter names are assumed to contain only _.a-zA-Z0-9#[]
    Note that only names that start with a letter and continue with
    letters, digits, '_' and '.' are recognised for substitution.

    The list of parameters is probably something like::

        parset.setprefix()
        pars = parset.flatten()

    Note that math uses acos while numpy uses arccos.  To avoid confusion
    we allow both.

    Should try running the function to identify syntax errors before
    running it in a fit.

    help(fn) shows the expressions evaluated by the returned function fn;
    dis.dis(fn) shows the corresponding python vm instructions.
    """
    # Available functions: the public names of math, numpy-style aliases
    # for the inverse trig functions, then the caller's context (which may
    # override either).  Dunder entries of math are skipped; the original
    # implementation removed them after exec anyway.
    namespace = dict((name, value) for name, value in math.__dict__.items()
                     if not name.startswith('_'))
    namespace.update(arcsin=math.asin, arccos=math.acos,
                     arctan=math.atan, arctan2=math.atan2)
    if context:
        namespace.update(context)

    # Sort the parameters in the order they need to be evaluated
    deps = find_dependencies(pars)
    if not deps:
        return no_constraints
    par_table, par_mapping = parameter_mapping(deps)
    order = order_dependencies(deps)

    # Make the parameters reachable through their short aliases P<i>.
    namespace.update(par_table)

    # Install an empty builtins table *before* exec.  exec only inserts the
    # real builtins when '__builtins__' is absent, and since Python 3.10 a
    # function captures its builtins when it is created, so removing them
    # after exec (as this code used to) no longer takes effect.
    namespace['__builtins__'] = {}

    # Define the function body: one assignment per computed parameter,
    # with parameter paths rewritten to their aliases.
    exprs = [p.path + "=" + p.expression for p in order]
    code = [substitute(s, par_mapping) for s in exprs]
    body = "\n    ".join(code) if code else "pass"
    functiondef = "def eval_expressions():\n    %s\n" % body

    # The exec(...) call form is valid on both Python 2 and Python 3.
    local_ns = {}
    exec(functiondef, namespace, local_ns)
    retfn = local_ns['eval_expressions']

    # Attach the original (unsubstituted) expressions as the docstring.
    # Setting it here rather than embedding it in functiondef means an
    # expression containing ''' cannot break the generated source.
    retfn.__doc__ = "\n".join(exprs)

    return retfn
164    retfn = locals['eval_expressions']
165
166    # Remove garbage added to globals by exec
167    globals.pop('__doc__',None)
168    globals.pop('__name__',None)
169    globals.pop('__file__',None)
170    globals.pop('__builtins__')
171    #print globals.keys()
172
173    return retfn
174
175def test():
176    import inspect, dis
177    import math
178   
179    symtab = {'a.b.x':1, 'a.c':2, 'a.b':3, 'b.x':4}
180    expr = 'a.b.x + sin(4*pi*a.c) + a.b.x/a.b'
181   
182    # Check symbol lookup
183    assert symbols(expr, symtab) == set([1,2,3])
184
185    # Check symbol rename
186    assert substitute(expr,{'a.b.x':'Q'}) == 'Q + sin(4*pi*a.c) + Q/a.b'
187    assert substitute(expr,{'a.b':'Q'}) == 'a.b.x + sin(4*pi*a.c) + a.b.x/Q'
188
189
190    # Check dependency builder
191    # Fake parameter class
192    class Parameter:
193        def __init__(self, name, value=0, expression=''):
194            self.path = name
195            self.value = value
196            self.expression = expression
197        def iscomputed(self): return (self.expression != '')
198        def __repr__(self): return self.path
199    p1 = Parameter('G0.sigma',5)
200    p2 = Parameter('other',expression='2*pi*sin(G0.sigma/.1875) + M1.G1')
201    p3 = Parameter('M1.G1',6)
202    p4 = Parameter('constant',expression='2*pi*35')
203    # Simple chain
204    assert set(find_dependencies([p1,p2,p3])) == set([(p2,p1),(p2,p3)])
205    # Constant expression
206    assert set(find_dependencies([p1,p4])) == set([(p4,None)])
207    # No dependencies
208    assert set(find_dependencies([p1,p3])) == set([])
209
210    # Check function builder
211    fn = build_eval([p1,p2,p3])
212
213    # Inspect the resulting function
214    if False:
215        print inspect.getdoc(fn)
216        print dis.dis(fn)
217
218    # Evaluate the function and see if it updates the
219    # target value as expected
220    fn()
221    expected = 2*math.pi*math.sin(5/.1875) + 6
222    assert p2.value == expected,"Value was %s, not %s"%(p2.value,expected)
223   
224    # Check empty dependency set doesn't crash
225    fn = build_eval([p1,p3])
226    fn()
227
228    # Check that constants are evaluated properly
229    fn = build_eval([p4])
230    fn()
231    assert p4.value == 2*math.pi*35
232
233    # Check additional context example; this also tests multiple
234    # expressions
235    class Table:
236        Si = 2.09
237        values = {'Si': 2.07}
238    tbl = Table()
239    p5 = Parameter('lookup',expression="tbl.Si")
240    fn = build_eval([p1,p2,p3,p5],context=dict(tbl=tbl))
241    fn()
242    assert p5.value == 2.09,"Value for %s was %s"%(p5.expression,p5.value)
243    p5.expression = "tbl.values['Si']"
244    fn = build_eval([p1,p2,p3,p5],context=dict(tbl=tbl))
245    fn()
246    assert p5.value == 2.07,"Value for %s was %s"%(p5.expression,p5.value)
247   
248
249    # Verify that we capture invalid expressions
250    for expr in ['G4.cage', 'M0.cage', 'M1.G1 + *2', 
251                 'piddle',
252                 '5; import sys; print "p0wned"',
253                 '__import__("sys").argv']:
254        try:
255            p6 = Parameter('broken',expression=expr)
256            fn = build_eval([p6])
257            fn()
258        except Exception,msg: pass
259        else:  raise "Failed to raise error for %s"%expr
260
261if __name__ == "__main__": test()
Note: See TracBrowser for help on using the repository browser.