source: sasmodels/sasmodels/codegen.py @ db03406

Last change on this file since db03406 was db03406, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

first pass at python to C translator for kernels

  • Property mode set to 100644
File size: 19.6 KB
Line 
1"""
2    cdegen
3    ~~~~~~~
4
5    Extension to ast that allow ast -> python code generation.
6
7    :copyright: Copyright 2008 by Armin Ronacher.
8    :license: BSD.
9"""
10from __future__ import print_function, division
11
12import ast
13from ast import NodeVisitor
14
15BINOP_SYMBOLS = {}
16BINOP_SYMBOLS[ast.Add] = '+'
17BINOP_SYMBOLS[ast.Sub] = '-'
18BINOP_SYMBOLS[ast.Mult] = '*'
19BINOP_SYMBOLS[ast.Div] = '/'
20BINOP_SYMBOLS[ast.Mod] = '%'
21BINOP_SYMBOLS[ast.Pow] = '**'
22BINOP_SYMBOLS[ast.LShift] = '<<'
23BINOP_SYMBOLS[ast.RShift] = '>>'
24BINOP_SYMBOLS[ast.BitOr] = '|'
25BINOP_SYMBOLS[ast.BitXor] = '^'
26BINOP_SYMBOLS[ast.BitAnd] = '&'
27BINOP_SYMBOLS[ast.FloorDiv] = '//'
28
29BOOLOP_SYMBOLS = {}
30BOOLOP_SYMBOLS[ast.And] = 'and'
31BOOLOP_SYMBOLS[ast.Or] = 'or'
32
33CMPOP_SYMBOLS = {}
34CMPOP_SYMBOLS[ast.Eq] = '=='
35CMPOP_SYMBOLS[ast.NotEq] = '!='
36CMPOP_SYMBOLS[ast.Lt] = '<'
37CMPOP_SYMBOLS[ast.LtE] = '<='
38CMPOP_SYMBOLS[ast.Gt] = '>'
39CMPOP_SYMBOLS[ast.GtE] = '>='
40CMPOP_SYMBOLS[ast.Is] = 'is'
41CMPOP_SYMBOLS[ast.IsNot] = 'is not'
42CMPOP_SYMBOLS[ast.In] = 'in'
43CMPOP_SYMBOLS[ast.NotIn] = 'not in'
44
45UNARYOP_SYMBOLS = {}
46UNARYOP_SYMBOLS[ast.Invert] = '~'
47UNARYOP_SYMBOLS[ast.Not] = 'not'
48UNARYOP_SYMBOLS[ast.UAdd] = '+'
49UNARYOP_SYMBOLS[ast.USub] = '-'
50
51
52def to_source(node, indent_with=' ' * 4, add_line_information=False):
53    """This function can convert a node tree back into python sourcecode.
54    This is useful for debugging purposes, especially if you're dealing with
55    custom asts not generated by python itself.
56
57    It could be that the sourcecode is evaluable when the AST itself is not
58    compilable / evaluable.  The reason for this is that the AST contains some
59    more data than regular sourcecode does, which is dropped during
60    conversion.
61
62    Each level of indentation is replaced with `indent_with`.  Per default this
63    parameter is equal to four spaces as suggested by PEP 8, but it might be
64    adjusted to match the application's styleguide.
65
66    If `add_line_information` is set to `True` comments for the line numbers
67    of the nodes are added to the output.  This can be used to spot wrong line
68    number information of statement nodes.
69    """
70    generator = SourceGenerator(indent_with, add_line_information)
71    generator.visit(node)
72
73    return ''.join(generator.result)
74
75class SourceGenerator(NodeVisitor):
76    """This visitor is able to transform a well formed syntax tree into python
77    sourcecode.  For more details have a look at the docstring of the
78    `node_to_source` function.
79    """
80
81    def __init__(self, indent_with, add_line_information=False):
82        self.result = []
83        self.indent_with = indent_with
84        self.add_line_information = add_line_information
85        self.indentation = 0
86        self.new_lines = 0
87
88    def write(self, x):
89        if self.new_lines:
90            if self.result:
91                self.result.append('\n' * self.new_lines)
92            self.result.append(self.indent_with * self.indentation)
93            self.new_lines = 0
94        self.result.append(x)
95
96    def newline(self, node=None, extra=0):
97        self.new_lines = max(self.new_lines, 1 + extra)
98        if node is not None and self.add_line_information:
99            self.write('# line: %s' % node.lineno)
100            self.new_lines = 1
101
102    def body(self, statements):
103        self.new_line = True
104        self.indentation += 1
105        for stmt in statements:
106            self.visit(stmt)
107        self.indentation -= 1
108
109    def body_or_else(self, node):
110        self.body(node.body)
111        if node.orelse:
112            self.newline()
113            self.write('else:')
114            self.body(node.orelse)
115
116    def signature(self, node):
117        want_comma = []
118        def write_comma():
119            if want_comma:
120                self.write(', ')
121            else:
122                want_comma.append(True)
123
124        padding = [None] * (len(node.args) - len(node.defaults))
125        for arg, default in zip(node.args, padding + node.defaults):
126            write_comma()
127            self.visit(arg)
128            if default is not None:
129                self.write('=')
130                self.visit(default)
131        if node.vararg is not None:
132            write_comma()
133            try:
134                self.write('*' + node.vararg.arg)
135            except AttributeError:  # CRUFT: python 2
136                self.write('*' + node.vararg)
137        if node.kwarg is not None:
138            write_comma()
139            try:
140                self.write('**' + node.kwarg.arg)
141            except AttributeError:  # CRUFT: python 2
142                self.write('*' + node.vararg)
143
144    def decorators(self, node):
145        for decorator in node.decorator_list:
146            self.newline(decorator)
147            self.write('@')
148            self.visit(decorator)
149
150    # Statements
151
152    def visit_Assert(self, node):
153        self.newline(node)
154        self.write('assert ')
155        self.visit(node.test)
156        if node.msg is not None:
157           self.write(', ')
158           self.visit(node.msg)
159
160    def visit_Assign(self, node):
161        self.newline(node)
162        for idx, target in enumerate(node.targets):
163            if idx:
164                self.write(', ')
165            self.visit(target)
166        self.write(' = ')
167        self.visit(node.value)
168
169    def visit_AugAssign(self, node):
170        self.newline(node)
171        self.visit(node.target)
172        self.write(' ' + BINOP_SYMBOLS[type(node.op)] + '= ')
173        self.visit(node.value)
174
175    def visit_ImportFrom(self, node):
176        self.newline(node)
177        self.write('from %s%s import ' % ('.' * node.level, node.module))
178        for idx, item in enumerate(node.names):
179            if idx:
180                self.write(', ')
181            self.write(item)
182
183    def visit_Import(self, node):
184        self.newline(node)
185        for item in node.names:
186            self.write('import ')
187            self.visit(item)
188
189    def visit_Expr(self, node):
190        self.newline(node)
191        self.generic_visit(node)
192
193    def visit_FunctionDef(self, node):
194        self.newline(extra=1)
195        self.decorators(node)
196        self.newline(node)
197        self.write('def %s(' % node.name)
198        self.visit(node.args)
199        self.write('):')
200        self.body(node.body)
201
202    def visit_ClassDef(self, node):
203        have_args = []
204        def paren_or_comma():
205            if have_args:
206                self.write(', ')
207            else:
208                have_args.append(True)
209                self.write('(')
210
211        self.newline(extra=2)
212        self.decorators(node)
213        self.newline(node)
214        self.write('class %s' % node.name)
215        for base in node.bases:
216            paren_or_comma()
217            self.visit(base)
218        # XXX: the if here is used to keep this module compatible
219        #      with python 2.6.
220        if hasattr(node, 'keywords'):
221            for keyword in node.keywords:
222                paren_or_comma()
223                self.write(keyword.arg + '=')
224                self.visit(keyword.value)
225            if node.starargs is not None:
226                paren_or_comma()
227                self.write('*')
228                self.visit(node.starargs)
229            if node.kwargs is not None:
230                paren_or_comma()
231                self.write('**')
232                self.visit(node.kwargs)
233        self.write(have_args and '):' or ':')
234        self.body(node.body)
235
236    def visit_If(self, node):
237        self.newline(node)
238        self.write('if ')
239        self.visit(node.test)
240        self.write(':')
241        self.body(node.body)
242        while True:
243            else_ = node.orelse
244            if len(else_) == 0:
245                break
246            elif len(else_) == 1 and isinstance(else_[0], ast.If):
247                node = else_[0]
248                self.newline()
249                self.write('elif ')
250                self.visit(node.test)
251                self.write(':')
252                self.body(node.body)
253            else:
254                self.newline()
255                self.write('else:')
256                self.body(else_)
257                break
258
259    def visit_For(self, node):
260        self.newline(node)
261        self.write('for ')
262        self.visit(node.target)
263        self.write(' in ')
264        self.visit(node.iter)
265        self.write(':')
266        self.body_or_else(node)
267
268    def visit_While(self, node):
269        self.newline(node)
270        self.write('while ')
271        self.visit(node.test)
272        self.write(':')
273        self.body_or_else(node)
274
275    def visit_With(self, node):
276        self.newline(node)
277        self.write('with ')
278        self.visit(node.context_expr)
279        if node.optional_vars is not None:
280            self.write(' as ')
281            self.visit(node.optional_vars)
282        self.write(':')
283        self.body(node.body)
284
285    def visit_Pass(self, node):
286        self.newline(node)
287        self.write('pass')
288
289    def visit_Print(self, node):
290        # XXX: python 2.6 only
291        self.newline(node)
292        self.write('print ')
293        want_comma = False
294        if node.dest is not None:
295            self.write(' >> ')
296            self.visit(node.dest)
297            want_comma = True
298        for value in node.values:
299            if want_comma:
300                self.write(', ')
301            self.visit(value)
302            want_comma = True
303        if not node.nl:
304            self.write(',')
305
306    def visit_Delete(self, node):
307        self.newline(node)
308        self.write('del ')
309        for idx, target in enumerate(node):
310            if idx:
311                self.write(', ')
312            self.visit(target)
313
314    def visit_TryExcept(self, node):
315        self.newline(node)
316        self.write('try:')
317        self.body(node.body)
318        for handler in node.handlers:
319            self.visit(handler)
320
321    def visit_TryFinally(self, node):
322        self.newline(node)
323        self.write('try:')
324        self.body(node.body)
325        self.newline(node)
326        self.write('finally:')
327        self.body(node.finalbody)
328
329    def visit_Global(self, node):
330        self.newline(node)
331        self.write('global ' + ', '.join(node.names))
332
333    def visit_Nonlocal(self, node):
334        self.newline(node)
335        self.write('nonlocal ' + ', '.join(node.names))
336
337    def visit_Return(self, node):
338        self.newline(node)
339        if node.value is None:
340            self.write('return')
341        else:
342            self.write('return ')
343            self.visit(node.value)
344
345    def visit_Break(self, node):
346        self.newline(node)
347        self.write('break')
348
349    def visit_Continue(self, node):
350        self.newline(node)
351        self.write('continue')
352
353    def visit_Raise(self, node):
354        # XXX: Python 2.6 / 3.0 compatibility
355        self.newline(node)
356        self.write('raise')
357        if hasattr(node, 'exc') and node.exc is not None:
358            self.write(' ')
359            self.visit(node.exc)
360            if node.cause is not None:
361                self.write(' from ')
362                self.visit(node.cause)
363        elif hasattr(node, 'type') and node.type is not None:
364            self.visit(node.type)
365            if node.inst is not None:
366                self.write(', ')
367                self.visit(node.inst)
368            if node.tback is not None:
369                self.write(', ')
370                self.visit(node.tback)
371
372    # Expressions
373
374    def visit_Attribute(self, node):
375        self.visit(node.value)
376        self.write('.' + node.attr)
377
378    def visit_Call(self, node):
379        want_comma = []
380        def write_comma():
381            if want_comma:
382                self.write(', ')
383            else:
384                want_comma.append(True)
385
386        self.visit(node.func)
387        self.write('(')
388        for arg in node.args:
389            write_comma()
390            self.visit(arg)
391        for keyword in node.keywords:
392            write_comma()
393            self.write(keyword.arg + '=')
394            self.visit(keyword.value)
395        if getattr(node, 'starargs', None) is not None:
396            write_comma()
397            self.write('*')
398            self.visit(node.starargs)
399        if getattr(node, 'kwargs', None) is not None:
400            write_comma()
401            self.write('**')
402            self.visit(node.kwargs)
403        self.write(')')
404
405    def visit_Name(self, node):
406        self.write(node.id)
407
408    def visit_Str(self, node):
409        self.write(repr(node.s))
410
411    def visit_Bytes(self, node):
412        self.write(repr(node.s))
413
414    def visit_Num(self, node):
415        self.write(repr(node.n))
416
417    def visit_Tuple(self, node):
418        self.write('(')
419        idx = -1
420        for idx, item in enumerate(node.elts):
421            if idx:
422                self.write(', ')
423            self.visit(item)
424        self.write(idx and ')' or ',)')
425
426    def sequence_visit(left, right):
427        def visit(self, node):
428            self.write(left)
429            for idx, item in enumerate(node.elts):
430                if idx:
431                    self.write(', ')
432                self.visit(item)
433            self.write(right)
434        return visit
435
436    visit_List = sequence_visit('[', ']')
437    visit_Set = sequence_visit('{', '}')
438    del sequence_visit
439
440    def visit_Dict(self, node):
441        self.write('{')
442        for idx, (key, value) in enumerate(zip(node.keys, node.values)):
443            if idx:
444                self.write(', ')
445            self.visit(key)
446            self.write(': ')
447            self.visit(value)
448        self.write('}')
449
450    def visit_BinOp(self, node):
451        self.visit(node.left)
452        self.write(' %s ' % BINOP_SYMBOLS[type(node.op)])
453        self.visit(node.right)
454
455    def visit_BoolOp(self, node):
456        self.write('(')
457        for idx, value in enumerate(node.values):
458            if idx:
459                self.write(' %s ' % BOOLOP_SYMBOLS[type(node.op)])
460            self.visit(value)
461        self.write(')')
462
463    def visit_Compare(self, node):
464        self.write('(')
465        self.visit(node.left)
466        for op, right in zip(node.ops, node.comparators):
467            self.write(' %s ' % CMPOP_SYMBOLS[type(op)])
468            self.visit(right)
469        self.write(')')
470
471    def visit_UnaryOp(self, node):
472        self.write('(')
473        op = UNARYOP_SYMBOLS[type(node.op)]
474        self.write(op)
475        if op == 'not':
476            self.write(' ')
477        self.visit(node.operand)
478        self.write(')')
479
480    def visit_Subscript(self, node):
481        self.visit(node.value)
482        self.write('[')
483        self.visit(node.slice)
484        self.write(']')
485
486    def visit_Slice(self, node):
487        if node.lower is not None:
488            self.visit(node.lower)
489        self.write(':')
490        if node.upper is not None:
491            self.visit(node.upper)
492        if node.step is not None:
493            self.write(':')
494            if not (isinstance(node.step, ast.Name) and node.step.id == 'None'):
495                self.visit(node.step)
496
497    def visit_ExtSlice(self, node):
498        for idx, item in node.dims:
499            if idx:
500                self.write(', ')
501            self.visit(item)
502
503    def visit_Yield(self, node):
504        self.write('yield ')
505        self.visit(node.value)
506
507    def visit_Lambda(self, node):
508        self.write('lambda ')
509        self.visit(node.args)
510        self.write(': ')
511        self.visit(node.body)
512
513    def visit_Ellipsis(self, node):
514        self.write('Ellipsis')
515
516    def generator_visit(left, right):
517        def visit(self, node):
518            self.write(left)
519            self.visit(node.elt)
520            for comprehension in node.generators:
521                self.visit(comprehension)
522            self.write(right)
523        return visit
524
525    visit_ListComp = generator_visit('[', ']')
526    visit_GeneratorExp = generator_visit('(', ')')
527    visit_SetComp = generator_visit('{', '}')
528    del generator_visit
529
530    def visit_DictComp(self, node):
531        self.write('{')
532        self.visit(node.key)
533        self.write(': ')
534        self.visit(node.value)
535        for comprehension in node.generators:
536            self.visit(comprehension)
537        self.write('}')
538
539    def visit_IfExp(self, node):
540        self.visit(node.body)
541        self.write(' if ')
542        self.visit(node.test)
543        self.write(' else ')
544        self.visit(node.orelse)
545
546    def visit_Starred(self, node):
547        self.write('*')
548        self.visit(node.value)
549
550    def visit_Repr(self, node):
551        # XXX: python 2.6 only
552        self.write('`')
553        self.visit(node.value)
554        self.write('`')
555
556    # Helper Nodes
557
558    def visit_alias(self, node):
559        self.write(node.name)
560        if node.asname is not None:
561            self.write(' as ' + node.asname)
562
563    def visit_comprehension(self, node):
564        self.write(' for ')
565        self.visit(node.target)
566        self.write(' in ')
567        self.visit(node.iter)
568        if node.ifs:
569            for if_ in node.ifs:
570                self.write(' if ')
571                self.visit(if_)
572
573    def visit_excepthandler(self, node):
574        self.newline(node)
575        self.write('except')
576        if node.type is not None:
577            self.write(' ')
578            self.visit(node.type)
579            if node.name is not None:
580                self.write(' as ')
581                self.visit(node.name)
582        self.write(':')
583        self.body(node.body)
584
585    def visit_arguments(self, node):
586        self.signature(node)
587
588# ===== inspect.getclosurevars backport begin =====
589
590# copied from python 3
591from inspect import ismethod, isfunction, ismodule
592from collections import namedtuple
593builtins = __builtins__
594ClosureVars = namedtuple('ClosureVars', 'nonlocals globals builtins unbound')
595def py2_getclosurevars(func):
596    """
597    Get the mapping of free variables to their current values.
598
599    Returns a named tuple of dicts mapping the current nonlocal, global
600    and builtin references as seen by the body of the function. A final
601    set of unbound names that could not be resolved is also provided.
602    """
603
604    if ismethod(func):
605        func = func.__func__
606
607    if not isfunction(func):
608        raise TypeError("'{!r}' is not a Python function".format(func))
609
610    code = func.__code__
611    # Nonlocal references are named in co_freevars and resolved
612    # by looking them up in __closure__ by positional index
613    if func.__closure__ is None:
614        nonlocal_vars = {}
615    else:
616        nonlocal_vars = {
617            var : cell.cell_contents
618            for var, cell in zip(code.co_freevars, func.__closure__)
619       }
620
621    # Global and builtin references are named in co_names and resolved
622    # by looking them up in __globals__ or __builtins__
623    global_ns = func.__globals__
624    builtin_ns = global_ns.get("__builtins__", builtins.__dict__)
625    if ismodule(builtin_ns):
626        builtin_ns = builtin_ns.__dict__
627    global_vars = {}
628    builtin_vars = {}
629    unbound_names = set()
630    for name in code.co_names:
631        if name in ("None", "True", "False"):
632            # Because these used to be builtins instead of keywords, they
633            # may still show up as name references. We ignore them.
634            continue
635        try:
636            global_vars[name] = global_ns[name]
637        except KeyError:
638            try:
639                builtin_vars[name] = builtin_ns[name]
640            except KeyError:
641                unbound_names.add(name)
642
643    return ClosureVars(nonlocal_vars, global_vars,
644                       builtin_vars, unbound_names)
645
646import inspect
647if not hasattr(inspect, 'getclosurevars'):
648    inspect.getclosurevars = py2_getclosurevars
649
650# ===== inspect.getclosurevars backport end ======
651
652def print_function(f=None):
653    """
654    Print out the code for the function
655    """
656    # Include some comments to see if they get printed
657    import ast
658    import inspect
659    print("function:", f.__code__.co_name)
660    print("closure:", inspect.getclosurevars(f))
661    print("locals:", f.__code__.co_nlocals, f.__code__.co_names)
662    if f is not None:
663        tree = ast.parse(inspect.getsource(f))
664        print(to_source(tree))
665
666from math import sin
667outside = '40 C'
668def _hello(*args, **kw):
669    x = sin(y) + cos(z.real)
670    print("world", outside)
671
672if __name__ == "__main__":
673    #print_function(print_function)
674    print_function(_hello)
Note: See TracBrowser for help on using the repository browser.