source: sasmodels/sasmodels/codegen.py @ 59ee4db

Last change on this file since 59ee4db was 59ee4db, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

replace sincos with sin,cos in python cylinder example

  • Property mode set to 100644
File size: 19.9 KB
Line 
1"""
2    cdegen
3    ~~~~~~~
4
5    Extension to ast that allow ast -> python code generation.
6
7    :copyright: Copyright 2008 by Armin Ronacher.
8    :license: BSD.
9"""
10from __future__ import print_function, division
11
12import ast
13from ast import NodeVisitor
14
15BINOP_SYMBOLS = {}
16BINOP_SYMBOLS[ast.Add] = '+'
17BINOP_SYMBOLS[ast.Sub] = '-'
18BINOP_SYMBOLS[ast.Mult] = '*'
19BINOP_SYMBOLS[ast.Div] = '/'
20BINOP_SYMBOLS[ast.Mod] = '%'
21BINOP_SYMBOLS[ast.Pow] = '**'
22BINOP_SYMBOLS[ast.LShift] = '<<'
23BINOP_SYMBOLS[ast.RShift] = '>>'
24BINOP_SYMBOLS[ast.BitOr] = '|'
25BINOP_SYMBOLS[ast.BitXor] = '^'
26BINOP_SYMBOLS[ast.BitAnd] = '&'
27BINOP_SYMBOLS[ast.FloorDiv] = '//'
28
29BOOLOP_SYMBOLS = {}
30BOOLOP_SYMBOLS[ast.And] = 'and'
31BOOLOP_SYMBOLS[ast.Or] = 'or'
32
33CMPOP_SYMBOLS = {}
34CMPOP_SYMBOLS[ast.Eq] = '=='
35CMPOP_SYMBOLS[ast.NotEq] = '!='
36CMPOP_SYMBOLS[ast.Lt] = '<'
37CMPOP_SYMBOLS[ast.LtE] = '<='
38CMPOP_SYMBOLS[ast.Gt] = '>'
39CMPOP_SYMBOLS[ast.GtE] = '>='
40CMPOP_SYMBOLS[ast.Is] = 'is'
41CMPOP_SYMBOLS[ast.IsNot] = 'is not'
42CMPOP_SYMBOLS[ast.In] = 'in'
43CMPOP_SYMBOLS[ast.NotIn] = 'not in'
44
45UNARYOP_SYMBOLS = {}
46UNARYOP_SYMBOLS[ast.Invert] = '~'
47UNARYOP_SYMBOLS[ast.Not] = 'not'
48UNARYOP_SYMBOLS[ast.UAdd] = '+'
49UNARYOP_SYMBOLS[ast.USub] = '-'
50
51
52def translate(functions, constants):
53    # type: (List[Tuple[str, str, int]], Dict[str, Any]) -> str
54    snippets = []
55    for source, filename, offset in functions:
56        tree = ast.parse(source)
57        snippet = to_source(tree) #, filename, offset)
58        snippets.append(snippet)
59    return "\n".join(snippets)
60
61def to_source(node, indent_with=' ' * 4, add_line_information=False):
62    """This function can convert a node tree back into python sourcecode.
63    This is useful for debugging purposes, especially if you're dealing with
64    custom asts not generated by python itself.
65
66    It could be that the sourcecode is evaluable when the AST itself is not
67    compilable / evaluable.  The reason for this is that the AST contains some
68    more data than regular sourcecode does, which is dropped during
69    conversion.
70
71    Each level of indentation is replaced with `indent_with`.  Per default this
72    parameter is equal to four spaces as suggested by PEP 8, but it might be
73    adjusted to match the application's styleguide.
74
75    If `add_line_information` is set to `True` comments for the line numbers
76    of the nodes are added to the output.  This can be used to spot wrong line
77    number information of statement nodes.
78    """
79    generator = SourceGenerator(indent_with, add_line_information)
80    generator.visit(node)
81
82    return ''.join(generator.result)
83
84class SourceGenerator(NodeVisitor):
85    """This visitor is able to transform a well formed syntax tree into python
86    sourcecode.  For more details have a look at the docstring of the
87    `node_to_source` function.
88    """
89
90    def __init__(self, indent_with, add_line_information=False):
91        self.result = []
92        self.indent_with = indent_with
93        self.add_line_information = add_line_information
94        self.indentation = 0
95        self.new_lines = 0
96
97    def write(self, x):
98        if self.new_lines:
99            if self.result:
100                self.result.append('\n' * self.new_lines)
101            self.result.append(self.indent_with * self.indentation)
102            self.new_lines = 0
103        self.result.append(x)
104
105    def newline(self, node=None, extra=0):
106        self.new_lines = max(self.new_lines, 1 + extra)
107        if node is not None and self.add_line_information:
108            self.write('# line: %s' % node.lineno)
109            self.new_lines = 1
110
111    def body(self, statements):
112        self.new_line = True
113        self.indentation += 1
114        for stmt in statements:
115            self.visit(stmt)
116        self.indentation -= 1
117
118    def body_or_else(self, node):
119        self.body(node.body)
120        if node.orelse:
121            self.newline()
122            self.write('else:')
123            self.body(node.orelse)
124
125    def signature(self, node):
126        want_comma = []
127        def write_comma():
128            if want_comma:
129                self.write(', ')
130            else:
131                want_comma.append(True)
132
133        padding = [None] * (len(node.args) - len(node.defaults))
134        for arg, default in zip(node.args, padding + node.defaults):
135            write_comma()
136            self.visit(arg)
137            if default is not None:
138                self.write('=')
139                self.visit(default)
140        if node.vararg is not None:
141            write_comma()
142            try:
143                self.write('*' + node.vararg.arg)
144            except AttributeError:  # CRUFT: python 2
145                self.write('*' + node.vararg)
146        if node.kwarg is not None:
147            write_comma()
148            try:
149                self.write('**' + node.kwarg.arg)
150            except AttributeError:  # CRUFT: python 2
151                self.write('*' + node.vararg)
152
153    def decorators(self, node):
154        for decorator in node.decorator_list:
155            self.newline(decorator)
156            self.write('@')
157            self.visit(decorator)
158
159    # Statements
160
161    def visit_Assert(self, node):
162        self.newline(node)
163        self.write('assert ')
164        self.visit(node.test)
165        if node.msg is not None:
166           self.write(', ')
167           self.visit(node.msg)
168
169    def visit_Assign(self, node):
170        self.newline(node)
171        for idx, target in enumerate(node.targets):
172            if idx:
173                self.write(', ')
174            self.visit(target)
175        self.write(' = ')
176        self.visit(node.value)
177
178    def visit_AugAssign(self, node):
179        self.newline(node)
180        self.visit(node.target)
181        self.write(' ' + BINOP_SYMBOLS[type(node.op)] + '= ')
182        self.visit(node.value)
183
184    def visit_ImportFrom(self, node):
185        self.newline(node)
186        self.write('from %s%s import ' % ('.' * node.level, node.module))
187        for idx, item in enumerate(node.names):
188            if idx:
189                self.write(', ')
190            self.write(item)
191
192    def visit_Import(self, node):
193        self.newline(node)
194        for item in node.names:
195            self.write('import ')
196            self.visit(item)
197
198    def visit_Expr(self, node):
199        self.newline(node)
200        self.generic_visit(node)
201
202    def visit_FunctionDef(self, node):
203        self.newline(extra=1)
204        self.decorators(node)
205        self.newline(node)
206        self.write('def %s(' % node.name)
207        self.visit(node.args)
208        self.write('):')
209        self.body(node.body)
210
211    def visit_ClassDef(self, node):
212        have_args = []
213        def paren_or_comma():
214            if have_args:
215                self.write(', ')
216            else:
217                have_args.append(True)
218                self.write('(')
219
220        self.newline(extra=2)
221        self.decorators(node)
222        self.newline(node)
223        self.write('class %s' % node.name)
224        for base in node.bases:
225            paren_or_comma()
226            self.visit(base)
227        # XXX: the if here is used to keep this module compatible
228        #      with python 2.6.
229        if hasattr(node, 'keywords'):
230            for keyword in node.keywords:
231                paren_or_comma()
232                self.write(keyword.arg + '=')
233                self.visit(keyword.value)
234            if node.starargs is not None:
235                paren_or_comma()
236                self.write('*')
237                self.visit(node.starargs)
238            if node.kwargs is not None:
239                paren_or_comma()
240                self.write('**')
241                self.visit(node.kwargs)
242        self.write(have_args and '):' or ':')
243        self.body(node.body)
244
245    def visit_If(self, node):
246        self.newline(node)
247        self.write('if ')
248        self.visit(node.test)
249        self.write(':')
250        self.body(node.body)
251        while True:
252            else_ = node.orelse
253            if len(else_) == 0:
254                break
255            elif len(else_) == 1 and isinstance(else_[0], ast.If):
256                node = else_[0]
257                self.newline()
258                self.write('elif ')
259                self.visit(node.test)
260                self.write(':')
261                self.body(node.body)
262            else:
263                self.newline()
264                self.write('else:')
265                self.body(else_)
266                break
267
268    def visit_For(self, node):
269        self.newline(node)
270        self.write('for ')
271        self.visit(node.target)
272        self.write(' in ')
273        self.visit(node.iter)
274        self.write(':')
275        self.body_or_else(node)
276
277    def visit_While(self, node):
278        self.newline(node)
279        self.write('while ')
280        self.visit(node.test)
281        self.write(':')
282        self.body_or_else(node)
283
284    def visit_With(self, node):
285        self.newline(node)
286        self.write('with ')
287        self.visit(node.context_expr)
288        if node.optional_vars is not None:
289            self.write(' as ')
290            self.visit(node.optional_vars)
291        self.write(':')
292        self.body(node.body)
293
294    def visit_Pass(self, node):
295        self.newline(node)
296        self.write('pass')
297
298    def visit_Print(self, node):
299        # XXX: python 2.6 only
300        self.newline(node)
301        self.write('print ')
302        want_comma = False
303        if node.dest is not None:
304            self.write(' >> ')
305            self.visit(node.dest)
306            want_comma = True
307        for value in node.values:
308            if want_comma:
309                self.write(', ')
310            self.visit(value)
311            want_comma = True
312        if not node.nl:
313            self.write(',')
314
315    def visit_Delete(self, node):
316        self.newline(node)
317        self.write('del ')
318        for idx, target in enumerate(node):
319            if idx:
320                self.write(', ')
321            self.visit(target)
322
323    def visit_TryExcept(self, node):
324        self.newline(node)
325        self.write('try:')
326        self.body(node.body)
327        for handler in node.handlers:
328            self.visit(handler)
329
330    def visit_TryFinally(self, node):
331        self.newline(node)
332        self.write('try:')
333        self.body(node.body)
334        self.newline(node)
335        self.write('finally:')
336        self.body(node.finalbody)
337
338    def visit_Global(self, node):
339        self.newline(node)
340        self.write('global ' + ', '.join(node.names))
341
342    def visit_Nonlocal(self, node):
343        self.newline(node)
344        self.write('nonlocal ' + ', '.join(node.names))
345
346    def visit_Return(self, node):
347        self.newline(node)
348        if node.value is None:
349            self.write('return')
350        else:
351            self.write('return ')
352            self.visit(node.value)
353
354    def visit_Break(self, node):
355        self.newline(node)
356        self.write('break')
357
358    def visit_Continue(self, node):
359        self.newline(node)
360        self.write('continue')
361
362    def visit_Raise(self, node):
363        # XXX: Python 2.6 / 3.0 compatibility
364        self.newline(node)
365        self.write('raise')
366        if hasattr(node, 'exc') and node.exc is not None:
367            self.write(' ')
368            self.visit(node.exc)
369            if node.cause is not None:
370                self.write(' from ')
371                self.visit(node.cause)
372        elif hasattr(node, 'type') and node.type is not None:
373            self.visit(node.type)
374            if node.inst is not None:
375                self.write(', ')
376                self.visit(node.inst)
377            if node.tback is not None:
378                self.write(', ')
379                self.visit(node.tback)
380
381    # Expressions
382
383    def visit_Attribute(self, node):
384        self.visit(node.value)
385        self.write('.' + node.attr)
386
387    def visit_Call(self, node):
388        want_comma = []
389        def write_comma():
390            if want_comma:
391                self.write(', ')
392            else:
393                want_comma.append(True)
394
395        self.visit(node.func)
396        self.write('(')
397        for arg in node.args:
398            write_comma()
399            self.visit(arg)
400        for keyword in node.keywords:
401            write_comma()
402            self.write(keyword.arg + '=')
403            self.visit(keyword.value)
404        if getattr(node, 'starargs', None) is not None:
405            write_comma()
406            self.write('*')
407            self.visit(node.starargs)
408        if getattr(node, 'kwargs', None) is not None:
409            write_comma()
410            self.write('**')
411            self.visit(node.kwargs)
412        self.write(')')
413
414    def visit_Name(self, node):
415        self.write(node.id)
416
417    def visit_Str(self, node):
418        self.write(repr(node.s))
419
420    def visit_Bytes(self, node):
421        self.write(repr(node.s))
422
423    def visit_Num(self, node):
424        self.write(repr(node.n))
425
426    def visit_Tuple(self, node):
427        self.write('(')
428        idx = -1
429        for idx, item in enumerate(node.elts):
430            if idx:
431                self.write(', ')
432            self.visit(item)
433        self.write(idx and ')' or ',)')
434
435    def sequence_visit(left, right):
436        def visit(self, node):
437            self.write(left)
438            for idx, item in enumerate(node.elts):
439                if idx:
440                    self.write(', ')
441                self.visit(item)
442            self.write(right)
443        return visit
444
445    visit_List = sequence_visit('[', ']')
446    visit_Set = sequence_visit('{', '}')
447    del sequence_visit
448
449    def visit_Dict(self, node):
450        self.write('{')
451        for idx, (key, value) in enumerate(zip(node.keys, node.values)):
452            if idx:
453                self.write(', ')
454            self.visit(key)
455            self.write(': ')
456            self.visit(value)
457        self.write('}')
458
459    def visit_BinOp(self, node):
460        self.visit(node.left)
461        self.write(' %s ' % BINOP_SYMBOLS[type(node.op)])
462        self.visit(node.right)
463
464    def visit_BoolOp(self, node):
465        self.write('(')
466        for idx, value in enumerate(node.values):
467            if idx:
468                self.write(' %s ' % BOOLOP_SYMBOLS[type(node.op)])
469            self.visit(value)
470        self.write(')')
471
472    def visit_Compare(self, node):
473        self.write('(')
474        self.visit(node.left)
475        for op, right in zip(node.ops, node.comparators):
476            self.write(' %s ' % CMPOP_SYMBOLS[type(op)])
477            self.visit(right)
478        self.write(')')
479
480    def visit_UnaryOp(self, node):
481        self.write('(')
482        op = UNARYOP_SYMBOLS[type(node.op)]
483        self.write(op)
484        if op == 'not':
485            self.write(' ')
486        self.visit(node.operand)
487        self.write(')')
488
489    def visit_Subscript(self, node):
490        self.visit(node.value)
491        self.write('[')
492        self.visit(node.slice)
493        self.write(']')
494
495    def visit_Slice(self, node):
496        if node.lower is not None:
497            self.visit(node.lower)
498        self.write(':')
499        if node.upper is not None:
500            self.visit(node.upper)
501        if node.step is not None:
502            self.write(':')
503            if not (isinstance(node.step, ast.Name) and node.step.id == 'None'):
504                self.visit(node.step)
505
506    def visit_ExtSlice(self, node):
507        for idx, item in node.dims:
508            if idx:
509                self.write(', ')
510            self.visit(item)
511
512    def visit_Yield(self, node):
513        self.write('yield ')
514        self.visit(node.value)
515
516    def visit_Lambda(self, node):
517        self.write('lambda ')
518        self.visit(node.args)
519        self.write(': ')
520        self.visit(node.body)
521
522    def visit_Ellipsis(self, node):
523        self.write('Ellipsis')
524
525    def generator_visit(left, right):
526        def visit(self, node):
527            self.write(left)
528            self.visit(node.elt)
529            for comprehension in node.generators:
530                self.visit(comprehension)
531            self.write(right)
532        return visit
533
534    visit_ListComp = generator_visit('[', ']')
535    visit_GeneratorExp = generator_visit('(', ')')
536    visit_SetComp = generator_visit('{', '}')
537    del generator_visit
538
539    def visit_DictComp(self, node):
540        self.write('{')
541        self.visit(node.key)
542        self.write(': ')
543        self.visit(node.value)
544        for comprehension in node.generators:
545            self.visit(comprehension)
546        self.write('}')
547
548    def visit_IfExp(self, node):
549        self.visit(node.body)
550        self.write(' if ')
551        self.visit(node.test)
552        self.write(' else ')
553        self.visit(node.orelse)
554
555    def visit_Starred(self, node):
556        self.write('*')
557        self.visit(node.value)
558
559    def visit_Repr(self, node):
560        # XXX: python 2.6 only
561        self.write('`')
562        self.visit(node.value)
563        self.write('`')
564
565    # Helper Nodes
566
567    def visit_alias(self, node):
568        self.write(node.name)
569        if node.asname is not None:
570            self.write(' as ' + node.asname)
571
572    def visit_comprehension(self, node):
573        self.write(' for ')
574        self.visit(node.target)
575        self.write(' in ')
576        self.visit(node.iter)
577        if node.ifs:
578            for if_ in node.ifs:
579                self.write(' if ')
580                self.visit(if_)
581
582    def visit_excepthandler(self, node):
583        self.newline(node)
584        self.write('except')
585        if node.type is not None:
586            self.write(' ')
587            self.visit(node.type)
588            if node.name is not None:
589                self.write(' as ')
590                self.visit(node.name)
591        self.write(':')
592        self.body(node.body)
593
594    def visit_arguments(self, node):
595        self.signature(node)
596
597# ===== inspect.getclosurevars backport begin =====
598
599# copied from python 3
600from inspect import ismethod, isfunction, ismodule
601from collections import namedtuple
602builtins = __builtins__
603ClosureVars = namedtuple('ClosureVars', 'nonlocals globals builtins unbound')
604def py2_getclosurevars(func):
605    """
606    Get the mapping of free variables to their current values.
607
608    Returns a named tuple of dicts mapping the current nonlocal, global
609    and builtin references as seen by the body of the function. A final
610    set of unbound names that could not be resolved is also provided.
611    """
612
613    if ismethod(func):
614        func = func.__func__
615
616    if not isfunction(func):
617        raise TypeError("'{!r}' is not a Python function".format(func))
618
619    code = func.__code__
620    # Nonlocal references are named in co_freevars and resolved
621    # by looking them up in __closure__ by positional index
622    if func.__closure__ is None:
623        nonlocal_vars = {}
624    else:
625        nonlocal_vars = {
626            var : cell.cell_contents
627            for var, cell in zip(code.co_freevars, func.__closure__)
628       }
629
630    # Global and builtin references are named in co_names and resolved
631    # by looking them up in __globals__ or __builtins__
632    global_ns = func.__globals__
633    builtin_ns = global_ns.get("__builtins__", builtins.__dict__)
634    if ismodule(builtin_ns):
635        builtin_ns = builtin_ns.__dict__
636    global_vars = {}
637    builtin_vars = {}
638    unbound_names = set()
639    for name in code.co_names:
640        if name in ("None", "True", "False"):
641            # Because these used to be builtins instead of keywords, they
642            # may still show up as name references. We ignore them.
643            continue
644        try:
645            global_vars[name] = global_ns[name]
646        except KeyError:
647            try:
648                builtin_vars[name] = builtin_ns[name]
649            except KeyError:
650                unbound_names.add(name)
651
652    return ClosureVars(nonlocal_vars, global_vars,
653                       builtin_vars, unbound_names)
654
655import inspect
656if not hasattr(inspect, 'getclosurevars'):
657    inspect.getclosurevars = py2_getclosurevars
658
659# ===== inspect.getclosurevars backport end ======
660
661def print_function(f=None):
662    """
663    Print out the code for the function
664    """
665    # Include some comments to see if they get printed
666    import ast
667    import inspect
668    print("function:", f.__code__.co_name)
669    print("closure:", inspect.getclosurevars(f))
670    print("locals:", f.__code__.co_nlocals, f.__code__.co_names)
671    if f is not None:
672        tree = ast.parse(inspect.getsource(f))
673        print(to_source(tree))
674
675from math import sin
676outside = '40 C'
677def _hello(*args, **kw):
678    x = sin(y) + cos(z.real)
679    print("world", outside)
680    a, b = x, y
681    [a, b] = (x, y)
682    (a, b) = [x, y]
683
684if __name__ == "__main__":
685    #print_function(print_function)
686    print_function(_hello)
Note: See TracBrowser for help on using the repository browser.