source: sasmodels/sasmodels/py2c.py @ 765d025

Last change on this file since 765d025 was 765d025, checked in by Paul Kienzle <pkienzle@…>, 5 years ago

Merge remote-tracking branch 'upstream/beta_approx'

  • Property mode set to 100644
File size: 44.8 KB
r"""
py2c
~~~~

Convert simple numeric python code into C code.

This code is intended to translate direct algorithms for scientific code
(mostly if statements and for loops operating on double precision values)
into C code. Unlike projects like numba, cython, pypy and nuitka, the
:func:`translate` function returns the corresponding C which can then be
compiled with tinycc or sent to the GPU using CUDA or OpenCL.

There is special handling for certain constructs, such as *for i in range* and
small integer powers.
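
As an illustration of these two special cases, here is a minimal sketch that
uses only the standard ast module (Python 3.9 or later, for ast.unparse).
The helpers ``range_header`` and ``power_to_c`` are hypothetical and are not
part of this module. The loop header mimics the form written by ``visit_For``
and assumes a positive step; the power rules follow the update notes below
(``square`` for 2, ``cube`` for 3, ``pow`` otherwise) and may differ in
detail from ``translate_power``.

```python
import ast


def range_header(loop_src):
    # Hypothetical helper: parse "for <name> in range(...): ..." and emit
    # a C loop header in the style used by visit_For.
    loop = ast.parse(loop_src).body[0]
    var = loop.target.id
    args = [ast.unparse(a) for a in loop.iter.args]
    start, stop, step = "0", None, "1"
    if len(args) == 1:        # range(stop)
        stop = args[0]
    elif len(args) == 2:      # range(start, stop)
        start, stop = args
    elif len(args) == 3:      # range(start, stop, step)
        start, stop, step = args
    else:
        raise ValueError("range() takes 1 to 3 arguments")
    # Assumes step > 0, since the loop test is always "<".
    return "for (%s=%s ; %s < %s ; %s += %s) {" % (var, start, var, stop, var, step)


def power_to_c(expr_src):
    # Hypothetical helper: translate "base ** exponent" to C.
    node = ast.parse(expr_src, mode="eval").body
    if not (isinstance(node, ast.BinOp) and isinstance(node.op, ast.Pow)):
        raise ValueError("not a power expression: " + expr_src)
    base = ast.unparse(node.left)
    exponent = node.right
    if isinstance(exponent, ast.Constant) and exponent.value == 2:
        return "square(%s)" % base
    if isinstance(exponent, ast.Constant) and exponent.value == 3:
        return "cube(%s)" % base
    return "pow(%s, %s)" % (base, ast.unparse(exponent))
```

For example, ``power_to_c("(a+b)**3")`` gives ``cube(a + b)`` and
``range_header("for i in range(2, n, 3): pass")`` gives
``for (i=2 ; i < n ; i += 3) {``.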

**TODO: make a nice list of supported constructs**

Imports are not supported, but they are at least ignored so that properly
constructed code can be run via python or translated to C without change.

Most other python constructs are **not** supported:

* classes
* builtin types (dict, set, list)
* exceptions
* with context
* del
* yield
* async
* list slicing
* multiple return values
* "is/is not", "in/not in" conditionals

There is limited support for lists and list comprehensions, so long as they
can be represented by a fixed array whose size is known at compile time, and
they are small enough to be stored on the stack.
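
A minimal sketch of what a fixed array known at compile time looks like,
assuming the list is a literal. The hypothetical helper ``list_to_c_array``
emits a stack-allocated C initializer in the same ``double vecN[] = {...};``
form that ``insert_c_vars`` writes for implicit arrays; it rejects anything
else, including comprehensions, rather than trying to expand them.

```python
import ast


def list_to_c_array(list_src, index):
    # Hypothetical helper: turn a list literal into a fixed-size C array
    # named vec<index>; the array length is simply len(node.elts).
    node = ast.parse(list_src, mode="eval").body
    if not isinstance(node, ast.List):
        raise ValueError("expected a list literal: " + list_src)
    items = ", ".join(ast.unparse(item) for item in node.elts)
    return "double vec%d[] = {%s};" % (index, items)
```

Generated names such as ``vec1`` are also why a user variable with the same
name can collide, as noted under Known issues.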

Variable definitions in C
-------------------------
Defining variables within the translate function involves some guesswork,
using the following rules:

*   By default, a variable is a 'double'.
*   The variable of a for loop is an int.
*   A variable that is referenced with brackets is an array of doubles, and
    the variable within the brackets is an integer. For example, in the
    reference 'var1[var2]', var1 is a double array and var2 is an integer.
*   Assignment to an argument makes that argument an array, and the index
    in that assignment is 0. For example, the following python code::

        def func(arg1, arg2):
            arg2 = 17.

    is translated to the following C code::

        double func(double arg1, double arg2[])
        {
            arg2[0] = 17.0;
        }

*   All functions are defined as double, even if there is no
    return statement.
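
The rules above can be sketched as a single pass over the syntax tree.
``guess_c_types`` is a hypothetical, simplified stand-in for the bookkeeping
this module does with ``c_vars``, ``c_int_vars`` and ``c_pointers``. It
assumes Python 3.9 or later (where a subscript index is stored directly in
``Subscript.slice``) and treats every called name as a function rather than
a variable.

```python
import ast


def guess_c_types(func_src):
    # Hypothetical helper: classify every name used in one "def".
    func = ast.parse(func_src).body[0]
    args = {a.arg for a in func.args.args}
    # Names that are called are functions, not variables.
    called = {n.func.id for n in ast.walk(func)
              if isinstance(n, ast.Call) and isinstance(n.func, ast.Name)}
    types = {}
    for node in ast.walk(func):
        if isinstance(node, ast.For) and isinstance(node.target, ast.Name):
            # Rule: the variable of a for loop is an int.
            types[node.target.id] = "int"
        elif isinstance(node, ast.Subscript):
            # Rule: var1[var2] makes var1 a double array and var2 an int.
            if isinstance(node.value, ast.Name):
                types[node.value.id] = "double[]"
            if isinstance(node.slice, ast.Name):
                types[node.slice.id] = "int"
        elif isinstance(node, ast.Assign):
            # Rule: assigning to an argument turns it into an array.
            for target in node.targets:
                if isinstance(target, ast.Name) and target.id in args:
                    types[target.id] = "double[]"
    # Rule: everything else defaults to double.
    for node in ast.walk(func):
        if isinstance(node, ast.Name) and node.id not in called:
            types.setdefault(node.id, "double")
    return types
```

For a function that sums ``x[i]`` over ``for i in range(3)`` into ``total``
and then assigns ``arg2 = total * arg1``, this reports ``i`` as int, ``x``
and ``arg2`` as double arrays, and ``total`` and ``arg1`` as double.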

Debugging
---------

*print* is partially supported using a simple regular expression. This
requires a stylized form. Be sure to use print as a function instead of
the print statement. If you are including substitution variables, use the
% string substitution style. Include parentheses around the substitution
tuple, even if there is only one item; do not include the final comma even
if it is a single item (yes, it won't be a tuple, but it makes the regexp
much simpler). Keep the whole print call on a single line. Here are three
forms that work::

    print("x") => printf("x\n");
    print("x %g"%(a)) => printf("x %g\n", a);
    print("x %g %g %g"%(a, b, c)) => printf("x %g %g %g\n", a, b, c);
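
One way such a regular expression could look is sketched below.
``PRINT_RE`` and ``print_to_printf`` are hypothetical names, and the pattern
is an assumption that accepts exactly the stylized forms above: a format
string, optionally followed by ``%`` and a parenthesized argument list that
contains no nested parentheses.

```python
import re

# Hypothetical pattern: print("fmt") or print("fmt"%(args)).
PRINT_RE = re.compile(r'print\("(?P<fmt>[^"]*)"(?:\s*%\s*\((?P<args>[^)]*)\))?\)')


def print_to_printf(line):
    # Rewrite one stylized print call as printf, keeping the surrounding text.
    match = PRINT_RE.search(line)
    if match is None:
        return line
    call = 'printf("%s\\n"' % match.group("fmt")   # append a C newline escape
    if match.group("args"):
        call += ", " + match.group("args")
    return line[:match.start()] + call + ");" + line[match.end():]
```

The ``[^)]*`` argument group is why a call such as ``f(a)`` inside the
substitution tuple is not supported by this sketch.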

You can generate *main* using the *if __name__ == "__main__":* construct.
This does a simple substitution with "def main():" before translation and
a substitution with "int main(int argc, double *argv[])" after translation.
The result is that the content of the *if* block becomes the content of *main*.
Along with the print statement, you can run and test a translation standalone
using::

    python py2c.py source.py
    cc source.c
    ./a.out
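
The two substitutions can be sketched as plain string replacements.
``python_main_to_def`` and ``c_main_signature`` are hypothetical helpers. The
sketch assumes the translated ``main`` is emitted as ``double main()``, since
every function is declared double, and it uses the conventional C
``char *argv[]``; adjust the replacement string to match the signature the
translator actually writes.

```python
MAIN_GUARD = 'if __name__ == "__main__":'


def python_main_to_def(source):
    # Before translation: the guarded block becomes an ordinary function
    # body, so it is translated like any other function.
    return source.replace(MAIN_GUARD, "def main():")


def c_main_signature(c_code, c_main="int main(int argc, char *argv[])"):
    # After translation: swap the generated double signature for a C entry point.
    return c_code.replace("double main()", c_main)
```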

Known issues
------------
The following constructs may cause problems:

* implicit arrays: possible namespace collision for variable "vec#"
* swap fails: "x,y = y,x" will set x==y
* top-level statements: code outside a function body causes errors
* line number skew: each statement should be tagged with its own #line
  to avoid skew as comments are skipped and loop bodies are wrapped with
  braces, etc.
* doesn't support docstrings
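
The swap problem comes from lowering a tuple assignment pair by pair. The
sketch below contrasts that naive lowering with one that first copies every
right-hand side into a temporary. ``lower_tuple_assign``, ``run`` and the
``_tmpN`` names are hypothetical; the temporaries are wrapped in their own C
block so that repeated swaps do not redeclare them.

```python
def lower_tuple_assign(targets, values):
    # Naive lowering: one assignment per pair, in order.
    naive = ["%s = %s;" % (t, v) for t, v in zip(targets, values)]
    # Safer lowering: evaluate all right-hand sides before assigning.
    temps = ["_tmp%d" % k for k in range(len(values))]
    safe = ["{"]
    safe += ["double %s = %s;" % (tmp, v) for tmp, v in zip(temps, values)]
    safe += ["%s = %s;" % (t, tmp) for t, tmp in zip(targets, temps)]
    safe += ["}"]
    return naive, safe


def run(statements, env):
    # Tiny interpreter for "name = name;" lines, to show the effect in order.
    env = dict(env)
    for stmt in statements:
        if stmt in ("{", "}"):
            continue
        lhs, rhs = stmt.rstrip(";").split(" = ")
        env[lhs.split()[-1]] = env[rhs]
    return env
```

With ``x = 1`` and ``y = 2``, running the naive lowering of ``x, y = y, x``
leaves both at 2, while the temporary-based lowering swaps them.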

References
----------

Based on a variant of codegen.py:

    https://github.com/andreif/codegen
    :copyright: Copyright 2008 by Armin Ronacher.
    :license: BSD.
"""

# Update Notes
# ============
# 2017-11-22, OE: Each 'visit_*' method is to build a C statement string. It
#                 should insert 4 spaces per indentation level. The 'body'
#                 method will combine all the strings, by adding the
#                 'current_statement' to the c_proc string list
# 2017-11-22, OE: variables, argument definition implemented.  Note: An
#                 argument is considered an array if it is the target of an
#                 assignment. In that case it is translated to <var>[0]
# 2017-11-27, OE: 'pow' basically working
# 2017-12-07, OE: Multiple assignment: a1,a2,...,an=b1,b2,...bn implemented
# 2017-12-07, OE: Power function, including special cases of
#                 square(x)(pow(x,2)) and cube(x)(pow(x,3)), implemented in
#                 translate_power, called from visit_BinOp
# 2017-12-07, OE: Translation of integer division, '//' in python, implemented
#                 in translate_integer_divide, called from visit_BinOp
# 2017-12-07, OE: C variable definition handled in 'define_c_vars'
#               : Python integer division, '//', translated to C in
#                 'translate_integer_divide'
# 2017-12-15, OE: Precedence maintained by writing opening and closing
#                 parentheses '(',')', in procedure 'visit_BinOp'.
# 2017-12-18, OE: Added call to 'add_current_line()' at the beginning
#                 of visit_Return
# 2018-01-03, PK: Update interface for use in sasmodels
# 2018-01-03, PK: support "expr if cond else expr" syntax
# 2018-01-03, PK: x//y => (int)((x)/(y)) and x/y => ((double)(x)/(double)(y))
# 2018-01-03, PK: True/False => true/false
# 2018-01-03, PK: f(x) was introducing an extra semicolon
# 2018-01-03, PK: simplistic print function, for debugging
# 2018-01-03, PK: while expr: ... => while (expr) { ... }
# 2018-01-04, OE: Fixed bug in 'visit_If': visiting node.orelse in case else exists.

from __future__ import print_function

import sys
import ast
from ast import NodeVisitor
from inspect import currentframe, getframeinfo

try: # for debugging, astor lets us print out the node as python
    import astor
except ImportError:
    pass

BINOP_SYMBOLS = {}
BINOP_SYMBOLS[ast.Add] = '+'
BINOP_SYMBOLS[ast.Sub] = '-'
BINOP_SYMBOLS[ast.Mult] = '*'
BINOP_SYMBOLS[ast.Div] = '/'
BINOP_SYMBOLS[ast.Mod] = '%'
BINOP_SYMBOLS[ast.Pow] = '**'
BINOP_SYMBOLS[ast.LShift] = '<<'
BINOP_SYMBOLS[ast.RShift] = '>>'
BINOP_SYMBOLS[ast.BitOr] = '|'
BINOP_SYMBOLS[ast.BitXor] = '^'
BINOP_SYMBOLS[ast.BitAnd] = '&'
BINOP_SYMBOLS[ast.FloorDiv] = '//'

BOOLOP_SYMBOLS = {}
BOOLOP_SYMBOLS[ast.And] = '&&'
BOOLOP_SYMBOLS[ast.Or] = '||'

CMPOP_SYMBOLS = {}
CMPOP_SYMBOLS[ast.Eq] = '=='
CMPOP_SYMBOLS[ast.NotEq] = '!='
CMPOP_SYMBOLS[ast.Lt] = '<'
CMPOP_SYMBOLS[ast.LtE] = '<='
CMPOP_SYMBOLS[ast.Gt] = '>'
CMPOP_SYMBOLS[ast.GtE] = '>='
CMPOP_SYMBOLS[ast.Is] = 'is'
CMPOP_SYMBOLS[ast.IsNot] = 'is not'
CMPOP_SYMBOLS[ast.In] = 'in'
CMPOP_SYMBOLS[ast.NotIn] = 'not in'

UNARYOP_SYMBOLS = {}
UNARYOP_SYMBOLS[ast.Invert] = '~'
UNARYOP_SYMBOLS[ast.Not] = 'not'
UNARYOP_SYMBOLS[ast.UAdd] = '+'
UNARYOP_SYMBOLS[ast.USub] = '-'


# TODO: should not allow eval of arbitrary python
def isevaluable(s):
    try:
        eval(s)
        return True
    except Exception:
        return False

def render_expression(tree):
    generator = SourceGenerator()
    generator.visit(tree)
    c_code = "".join(generator.current_statement)
    return c_code

class SourceGenerator(NodeVisitor):
    """This visitor transforms a well-formed Python syntax tree into C
    source code.
    """

    def __init__(self, indent_with="    ", constants=None, fname=None, lineno=0):
        self.indent_with = indent_with
        self.indentation = 0

        # for C
        self.c_proc = []
        self.signature_line = 0
        self.arguments = []
        self.current_function = ""
        self.fname = fname
        self.lineno_offset = lineno
        self.warnings = []
        self.current_statement = ""
        # TODO: use set rather than list for c_vars, ...
        self.c_vars = []
        self.c_int_vars = []
        self.c_pointers = []
        self.c_dcl_pointers = []
        self.c_functions = []
        self.c_vectors = []
        self.c_constants = constants if constants is not None else {}
        self.in_expr = False
        self.in_subref = False
        self.in_subscript = False
        self.tuples = []
        self.required_functions = []
        self.visited_args = False

    def write_c(self, statement):
        # TODO: build up as a list rather than adding to string
        self.current_statement += statement

    def write_python(self, x):
        raise NotImplementedError("shouldn't be trying to write python")

    def add_c_line(self, line):
        indentation = self.indent_with * self.indentation
        self.c_proc.append("".join((indentation, line, "\n")))

    def add_current_line(self):
        if self.current_statement:
            self.add_c_line(self.current_statement)
            self.current_statement = ''

    def add_unique_var(self, new_var):
        if new_var not in self.c_vars:
            self.c_vars.append(str(new_var))

    def write_sincos(self, node):
        angle = str(node.args[0].id)
        self.write_c(node.args[1].id + " = sin(" + angle + ");")
        self.add_current_line()
        self.write_c(node.args[2].id + " = cos(" + angle + ");")
        self.add_current_line()
        for arg in node.args:
            self.add_unique_var(arg.id)

    def track_lineno(self, node):
        #print("newline", node, [s for s in dir(node) if not s.startswith('_')])
        if hasattr(node, 'lineno'):
            line = '#line %d "%s"\n' % (node.lineno+self.lineno_offset-1, self.fname)
            self.c_proc.append(line)

    def body(self, statements):
        if self.current_statement:
            self.add_current_line()
        self.new_line = True
        self.indentation += 1
        for stmt in statements:
            #if hasattr(stmt, 'targets') and hasattr(stmt.targets[0], 'id'):
            #    target_name = stmt.targets[0].id # target name needed for debug only
            self.visit(stmt)
        self.add_current_line() # just for breaking point. to be deleted.
        self.indentation -= 1

    def body_or_else(self, node):
        self.body(node.body)
        if node.orelse:
            self.unsupported(node, "for...else/while...else not supported")

            self.track_lineno(node)
            self.write_c('else:')
            self.body(node.orelse)

    def signature(self, node):
        want_comma = []
        def write_comma():
            if want_comma:
                self.write_c(', ')
            else:
                want_comma.append(True)

        # for C
        for arg in node.args:
            # CRUFT: 2.7 uses arg.id, 3.x uses arg.arg
            try:
                arg_name = arg.arg
            except AttributeError:
                arg_name = arg.id
            self.arguments.append(arg_name)

        padding = [None] *(len(node.args) - len(node.defaults))
        for arg, default in zip(node.args, padding + node.defaults):
            if default is not None:
                # CRUFT: 2.7 uses arg.id, 3.x uses arg.arg
                try:
                    arg_name = arg.arg
                except AttributeError:
                    arg_name = arg.id
                w_str = ("C does not support default parameters: %s=%s"
                         % (arg_name, str(default.n)))
                self.warnings.append(w_str)

    def decorators(self, node):
        if node.decorator_list:
            self.unsupported(node.decorator_list[0])
        for decorator in node.decorator_list:
            self.track_lineno(decorator)
            self.write_python('@')
            self.visit(decorator)

    # Statements

    def visit_Assert(self, node):
        self.unsupported(node)

        self.track_lineno(node)
        self.write_c('assert ')
        self.visit(node.test)
        if node.msg is not None:
            self.write_python(', ')
            self.visit(node.msg)

    def define_c_vars(self, target):
        if hasattr(target, 'id'):
            # A variable is considered an array if it appears in the argument
            # list and is assigned to. For example, the variable p in the
            # following snippet is a pointer, while q is not:
            # def somefunc(p, q):
            #  p = q + 1
            #  return
            #
            if target.id not in self.c_vars:
                if target.id in self.arguments:
                    idx = self.arguments.index(target.id)
                    new_target = self.arguments[idx] + "[0]"
                    if new_target not in self.c_pointers:
                        target.id = new_target
                        self.c_pointers.append(self.arguments[idx])
                else:
                    self.c_vars.append(target.id)

    def add_semi_colon(self):
        #semi_pos = self.current_statement.find(';')
        #if semi_pos >= 0:
        #    self.current_statement = self.current_statement.replace(';', '')
        self.write_c(';')

    def visit_Assign(self, node):
        self.add_current_line()
        self.track_lineno(node)
        self.in_expr = True
        for idx, target in enumerate(node.targets): # multi assign, as in 'a = b = c = 7'
            if idx:
                self.write_c(' = ')
            self.define_c_vars(target)
            self.visit(target)
        # Capture assigned tuple names, if any
        targets = self.tuples[:]
        del self.tuples[:]
        self.write_c(' = ')
        self.visited_args = False
        self.visit(node.value)
        self.add_semi_colon()
        self.add_current_line()
        # Assign tuples to tuples, if any
        # TODO: doesn't handle swap:  a,b = b,a
        for target, item in zip(targets, self.tuples):
            self.visit(target)
            self.write_c(' = ')
            self.visit(item)
            self.add_semi_colon()
            self.add_current_line()
        #if self.is_sequence and not self.visited_args:
        #    for target in node.targets:
        #        if hasattr(target, 'id'):
        #            if target.id in self.c_vars and target.id not in self.c_dcl_pointers:
        #                if target.id not in self.c_dcl_pointers:
        #                    self.c_dcl_pointers.append(target.id)
        #                    if target.id in self.c_vars:
        #                        self.c_vars.remove(target.id)
        self.current_statement = ''
        self.in_expr = False

    def visit_AugAssign(self, node):
        if node.target.id not in self.c_vars:
            if node.target.id not in self.arguments:
                self.c_vars.append(node.target.id)
        self.in_expr = True
        self.visit(node.target)
        self.write_c(' ' + BINOP_SYMBOLS[type(node.op)] + '= ')
        self.visit(node.value)
        self.add_semi_colon()
        self.in_expr = False
        self.add_current_line()

    def visit_ImportFrom(self, node):
        return  # import ignored
        self.track_lineno(node)
        self.write_python('from %s%s import ' %('.' * node.level, node.module))
        for idx, item in enumerate(node.names):
            if idx:
                self.write_python(', ')
            self.write_python(item)

    def visit_Import(self, node):
        return  # import ignored
        self.track_lineno(node)
        for item in node.names:
            self.write_python('import ')
            self.visit(item)

    def visit_Expr(self, node):
        #self.in_expr = True
        #self.track_lineno(node)
        self.generic_visit(node)
        #self.in_expr = False

    def write_c_pointers(self, start_var):
        if self.c_dcl_pointers:
            var_list = []
            for c_ptr in self.c_dcl_pointers:
                if c_ptr not in self.arguments:
                    var_list.append("*" + c_ptr)
                if c_ptr in self.c_vars:
                    self.c_vars.remove(c_ptr)
            if var_list:
                c_dcl = "    double " + ", ".join(var_list) + ";\n"
                self.c_proc.insert(start_var, c_dcl)
                start_var += 1
        return start_var

    def insert_c_vars(self, start_var):
        have_decls = False
        start_var = self.write_c_pointers(start_var)
        if self.c_int_vars:
            for var in self.c_int_vars:
                if var in self.c_vars:
                    self.c_vars.remove(var)
            decls = ", ".join(self.c_int_vars)
            self.c_proc.insert(start_var, "    int " + decls + ";\n")
            have_decls = True
            start_var += 1

        if self.c_vars:
            decls = ", ".join(self.c_vars)
            self.c_proc.insert(start_var, "    double " + decls + ";\n")
            have_decls = True
            start_var += 1

        if self.c_vectors:
            for vec_number, vec_value in enumerate(self.c_vectors):
                name = "vec" + str(vec_number + 1)
                decl = "    double " + name + "[] = {" + vec_value + "};"
                self.c_proc.insert(start_var, decl + "\n")
                start_var += 1

        del self.c_vars[:]
        del self.c_int_vars[:]
        del self.c_vectors[:]
        del self.c_pointers[:]
        del self.c_dcl_pointers[:]
        if have_decls:
            self.c_proc.insert(start_var, "\n")

    def insert_signature(self):
        arg_decls = []
        for arg in self.arguments:
            decl = "double " + arg
            if arg in self.c_pointers:
                decl += "[]"
            arg_decls.append(decl)
        args_str = ", ".join(arg_decls)
        method_sig = 'double ' + self.current_function + '(' + args_str + ")"
        if self.signature_line >= 0:
            self.c_proc.insert(self.signature_line, method_sig)

    def visit_FunctionDef(self, node):
        if self.current_function:
            self.unsupported(node, "function within a function")
        self.current_function = node.name

        # remember the location of the next warning that will be inserted
        # so that we can stuff the function name ahead of the warning list
        # if any warnings are generated by the function.
        warning_index = len(self.warnings)

        self.decorators(node)
        self.track_lineno(node)
        self.arguments = []
        self.visit(node.args)
        # for C
        self.signature_line = len(self.c_proc)
        self.add_c_line("\n{")
        start_vars = len(self.c_proc) + 1
        self.body(node.body)
        self.add_c_line("}\n")
        self.insert_signature()
        self.insert_c_vars(start_vars)
        del self.c_pointers[:]
        self.current_function = ""
        if warning_index != len(self.warnings):
            self.warnings.insert(warning_index, "Warning in function '" + node.name + "':")

    def visit_ClassDef(self, node):
        self.unsupported(node)

        have_args = []
        def paren_or_comma():
            if have_args:
                self.write_python(', ')
            else:
                have_args.append(True)
                self.write_python('(')

        self.decorators(node)
        self.track_lineno(node)
        self.write_python('class %s' % node.name)
        for base in node.bases:
            paren_or_comma()
            self.visit(base)
        # CRUFT: python 2.6 does not have "keywords" attribute
        if hasattr(node, 'keywords'):
            for keyword in node.keywords:
                paren_or_comma()
                self.write_python(keyword.arg + '=')
                self.visit(keyword.value)
            if node.starargs is not None:
                paren_or_comma()
                self.write_python('*')
                self.visit(node.starargs)
            if node.kwargs is not None:
                paren_or_comma()
                self.write_python('**')
                self.visit(node.kwargs)
        self.write_python(have_args and '):' or ':')
        self.body(node.body)

    def visit_If(self, node):

        self.track_lineno(node)
        self.write_c('if ')
        self.in_expr = True
        self.visit(node.test)
        self.in_expr = False
        self.write_c(' {')
        self.body(node.body)
        self.add_c_line('}')
        while True:
            else_ = node.orelse
            if len(else_) == 0:
                break
            #elif hasattr(else_, 'orelse'):
            elif len(else_) == 1 and isinstance(else_[0], ast.If):
                node = else_[0]
                self.track_lineno(node)
                self.write_c('else if ')
                self.in_expr = True
                self.visit(node.test)
                self.in_expr = False
                self.write_c(' {')
                self.body(node.body)
                self.add_current_line()
                self.add_c_line('}')
                #break
            else:
                self.track_lineno(else_)
                self.write_c('else {')
                self.body(else_)
                self.add_c_line('}')
                break

    def get_for_range(self, node):
        stop = ""
        start = '0'
        step = '1'
        for_args = []
        temp_statement = self.current_statement
        self.current_statement = ''
        for arg in node.iter.args:
            self.visit(arg)
            for_args.append(self.current_statement)
            self.current_statement = ''
        self.current_statement = temp_statement
        if len(for_args) == 1:
            stop = for_args[0]
        elif len(for_args) == 2:
            start = for_args[0]
            stop = for_args[1]
        elif len(for_args) == 3:
            start = for_args[0]
            stop = for_args[1]
            step = for_args[2]
        else:
            raise ValueError("Illegal for loop parameters")
        return start, stop, step

    def add_c_int_var(self, name):
        if name not in self.c_int_vars:
            self.c_int_vars.append(name)

    def visit_For(self, node):
        # node: for iterator is stored in node.target.
        # Iterator name is in node.target.id.
        self.add_current_line()
        fForDone = False
        self.current_statement = ''
        if hasattr(node.iter, 'func'):
            if hasattr(node.iter.func, 'id'):
                if node.iter.func.id == 'range':
                    self.visit(node.target)
                    iterator = self.current_statement
                    self.current_statement = ''
                    self.add_c_int_var(iterator)
                    start, stop, step = self.get_for_range(node)
                    self.write_c("for (" + iterator + "=" + str(start) +
                                 " ; " + iterator + " < " + str(stop) +
                                 " ; " + iterator + " += " + str(step) + ") {")
                    self.body_or_else(node)
                    self.write_c("}")
                    fForDone = True
        if not fForDone:
            # Generate the statement that is causing the error
            self.current_statement = ''
            self.write_c('for ')
            self.visit(node.target)
            self.write_c(' in ')
            self.visit(node.iter)
            self.write_c(':')
            # report the error
            self.unsupported(node, "unsupported " + self.current_statement)

    def visit_While(self, node):
        self.track_lineno(node)
        self.write_c('while ')
        self.visit(node.test)
        self.write_c(' {')
        self.body_or_else(node)
        self.write_c('}')
        self.add_current_line()

    def visit_With(self, node):
        self.unsupported(node)

        self.track_lineno(node)
        self.write_python('with ')
        self.visit(node.context_expr)
        if node.optional_vars is not None:
            self.write_python(' as ')
            self.visit(node.optional_vars)
        self.write_python(':')
        self.body(node.body)

    def visit_Pass(self, node):
        #self.track_lineno(node)
        #self.write_python('pass')
        pass

    def visit_Print(self, node):
        self.unsupported(node)

        # CRUFT: python 2 only
        self.track_lineno(node)
        self.write_c('print ')
        want_comma = False
        if node.dest is not None:
            self.write_c(' >> ')
            self.visit(node.dest)
            want_comma = True
        for value in node.values:
            if want_comma:
                self.write_c(', ')
            self.visit(value)
            want_comma = True
        if not node.nl:
            self.write_c(',')

    def visit_Delete(self, node):
        self.unsupported(node)

        self.track_lineno(node)
        self.write_python('del ')
        for idx, target in enumerate(node.targets):
            if idx:
                self.write_python(', ')
            self.visit(target)

    def visit_TryExcept(self, node):
        self.unsupported(node)

        self.track_lineno(node)
        self.write_python('try:')
        self.body(node.body)
        for handler in node.handlers:
            self.visit(handler)

    def visit_TryFinally(self, node):
        self.unsupported(node)

        self.track_lineno(node)
        self.write_python('try:')
        self.body(node.body)
        self.track_lineno(node)
        self.write_python('finally:')
        self.body(node.finalbody)

    def visit_Global(self, node):
        self.unsupported(node)

        self.track_lineno(node)
        self.write_python('global ' + ', '.join(node.names))

    def visit_Nonlocal(self, node):
        self.track_lineno(node)
        self.write_python('nonlocal ' + ', '.join(node.names))

    def visit_Return(self, node):
        self.add_current_line()
        self.track_lineno(node)
        self.in_expr = True
        if node.value is None:
            self.write_c('return')
        else:
            self.write_c('return ')
            self.visit(node.value)
        self.add_semi_colon()
        self.in_expr = False
        self.add_c_line(self.current_statement)
        self.current_statement = ''

    def visit_Break(self, node):
        self.track_lineno(node)
        self.write_c('break')

    def visit_Continue(self, node):
        self.track_lineno(node)
        self.write_c('continue')

    def visit_Raise(self, node):
        self.unsupported(node)

        # CRUFT: Python 2.6 / 3.0 compatibility
        self.track_lineno(node)
        self.write_python('raise')
        if hasattr(node, 'exc') and node.exc is not None:
            self.write_python(' ')
            self.visit(node.exc)
            if node.cause is not None:
                self.write_python(' from ')
                self.visit(node.cause)
        elif hasattr(node, 'type') and node.type is not None:
            self.visit(node.type)
            if node.inst is not None:
                self.write_python(', ')
                self.visit(node.inst)
            if node.tback is not None:
                self.write_python(', ')
                self.visit(node.tback)

    # Expressions

    def visit_Attribute(self, node):
        self.unsupported(node, "attribute reference a.b not supported")

        self.visit(node.value)
        self.write_python('.' + node.attr)

    def visit_Call(self, node):
        want_comma = []
        def write_comma():
            if want_comma:
                self.write_c(', ')
            else:
                want_comma.append(True)
        if hasattr(node.func, 'id'):
            if node.func.id not in self.c_functions:
                self.c_functions.append(node.func.id)
            if node.func.id == 'abs':
                self.write_c("fabs ")
            elif node.func.id == 'int':
                self.write_c('(int) ')
            elif node.func.id == "SINCOS":
                self.write_sincos(node)
                return
            else:
                self.visit(node.func)
        else:
            self.visit(node.func)
        self.write_c('(')
        for arg in node.args:
            write_comma()
            self.visited_args = True
            self.visit(arg)
        for keyword in node.keywords:
            write_comma()
            self.write_c(keyword.arg + '=')
            self.visit(keyword.value)
        if hasattr(node, 'starargs'):
            if node.starargs is not None:
                write_comma()
                self.write_c('*')
                self.visit(node.starargs)
        if hasattr(node, 'kwargs'):
            if node.kwargs is not None:
                write_comma()
                self.write_c('**')
                self.visit(node.kwargs)
        self.write_c(')')
        if not self.in_expr:
            self.add_semi_colon()

835    TRANSLATE_CONSTANTS = {
836        # python 2 uses normal name references through vist_Name
837        'True': 'true',
838        'False': 'false',
839        'None': 'NULL',  # "None" will probably fail for other reasons
840        # python 3 uses NameConstant
841        True: 'true',
842        False: 'false',
843        None: 'NULL',  # "None" will probably fail for other reasons
844        }
845
    def visit_Name(self, node):
        translation = self.TRANSLATE_CONSTANTS.get(node.id, None)
        if translation:
            self.write_c(translation)
            return
        self.write_c(node.id)
        if node.id in self.c_pointers and not self.in_subref:
            self.write_c("[0]")
        name = ""
        sub = node.id.find("[")
        if sub > 0:
            name = node.id[0:sub].strip()
        else:
            name = node.id
        # add the variable to c_vars (or to the int variables when it is used
        # as an index) unless it is already known, or is a function, an
        # argument, a constant or a number
        if (name not in self.c_functions and name not in self.c_vars and
                name not in self.c_int_vars and name not in self.arguments and
                name not in self.c_constants and not name.isdigit()):
            if self.in_subscript:
                self.add_c_int_var(node.id)
            else:
                self.c_vars.append(node.id)

    def visit_NameConstant(self, node):
        translation = self.TRANSLATE_CONSTANTS.get(node.value, None)
        if translation is not None:
            self.write_c(translation)
        else:
            self.unsupported(node, "don't know how to translate %r"%node.value)

    def visit_Str(self, node):
        s = node.s
        s = s.replace('\\', '\\\\').replace('"', '\\"').replace('\n', '\\n')
        self.write_c('"')
        self.write_c(s)
        self.write_c('"')

    def visit_Bytes(self, node):
        # node.s is a bytes object under python 3; decode it so that the
        # str-based escaping below works
        s = node.s.decode('latin-1')
        s = s.replace('\\', '\\\\').replace('"', '\\"').replace('\n', '\\n')
        self.write_c('"')
        self.write_c(s)
        self.write_c('"')

    def visit_Num(self, node):
        self.write_c(repr(node.n))

    def visit_Tuple(self, node):
        for idx, item in enumerate(node.elts):
            if idx:
                self.tuples.append(item)
            else:
                self.visit(item)

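    # visit_List: a list literal such as [a, b] is rendered as the string
    # "a, b", which is appended to self.c_vectors.  The expression itself is
    # replaced by a generated name: vec1 for the first list, vec2 for the
    # second, and so on.  An empty list produces no output at all.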
    def visit_List(self, node):
        #self.unsupported(node)
        #print("visiting", node)
        #print(astor.to_source(node))
        #print(node.elts)
        exprs = [render_expression(item) for item in node.elts]
        if exprs:
            self.c_vectors.append(', '.join(exprs))
            vec_name = "vec" + str(len(self.c_vectors))
            self.write_c(vec_name)

    def visit_Set(self, node):
        self.unsupported(node)

    def visit_Dict(self, node):
        self.unsupported(node)

        self.write_python('{')
        for idx, (key, value) in enumerate(zip(node.keys, node.values)):
            if idx:
                self.write_python(', ')
            self.visit(key)
            self.write_python(': ')
            self.visit(value)
        self.write_python('}')

    def get_special_power(self, string):
        function_name = ''
        is_negative_exp = False
        if isevaluable(str(self.current_statement)):
            exponent = eval(string)
            is_negative_exp = exponent < 0
            abs_exponent = abs(exponent)
            if abs_exponent == 2:
                function_name = "square"
            elif abs_exponent == 3:
                function_name = "cube"
            elif abs_exponent == 0.5:
                function_name = "sqrt"
            elif abs_exponent == 1.0/3.0:
                function_name = "cbrt"
        if function_name == '':
            # pow() receives the signed exponent, so no reciprocal is needed;
            # otherwise x**-2.5 would be translated as 1.0/pow(x, -2.5)
            function_name = "pow"
            is_negative_exp = False
        return function_name, is_negative_exp

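    # Illustrative translations produced by translate_power, assuming the
    # exponent renders as a plain number (the outer parentheses added by
    # visit_BinOp are omitted):
    #     x**2    ->  square(x)
    #     x**-3   ->  1.0 /(cube(x))
    #     x**0.5  ->  sqrt(x)
    #     x**2.5  ->  pow(x, 2.5)
    #     x**n    ->  pow(x, n)    (n does not evaluate to a number)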
    def translate_power(self, node):
        # Render the exponent (the right-hand operand) on its own to decide
        # which C function to call.  'visit' methods write their output to
        # the 'current_statement' attribute, so save the statement built so
        # far in 'temp_statement' and restore it afterwards.
        temp_statement = self.current_statement
        self.current_statement = ''
        self.visit(node.right)
        function_name, is_negative_exp = self.get_special_power(self.current_statement)
        self.current_statement = temp_statement
        if is_negative_exp:
            self.write_c("1.0 /(")
        self.write_c(function_name + "(")
        self.visit(node.left)
        if function_name == "pow":
            self.write_c(", ")
            self.visit(node.right)
        self.write_c(")")
        if is_negative_exp:
            self.write_c(")")
        #self.write_c(" ")

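    # Illustrative translations of division (the outer parentheses added by
    # visit_BinOp are omitted):
    #     i // 2  ->  (int)((i)/(2))
    #     x / 2   ->  ((double)(x)/(double)(2))
    # Note that the (int) cast truncates toward zero while Python's // floors,
    # so the results differ for negative quotients: -7 // 2 is -4 in Python,
    # but the translated C expression evaluates to -3.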
    def translate_integer_divide(self, node):
        self.write_c("(int)((")
        self.visit(node.left)
        self.write_c(")/(")
        self.visit(node.right)
        self.write_c("))")

    def translate_float_divide(self, node):
        self.write_c("((double)(")
        self.visit(node.left)
        self.write_c(")/(double)(")
        self.visit(node.right)
        self.write_c("))")

    def visit_BinOp(self, node):
        self.write_c("(")
        if isinstance(node.op, ast.Pow):
            self.translate_power(node)
        elif isinstance(node.op, ast.FloorDiv):
            self.translate_integer_divide(node)
        elif isinstance(node.op, ast.Div):
            self.translate_float_divide(node)
        else:
            self.visit(node.left)
            self.write_c(' %s ' % BINOP_SYMBOLS[type(node.op)])
            self.visit(node.right)
        self.write_c(")")

    # for C
    def visit_BoolOp(self, node):
        self.write_c('(')
        for idx, value in enumerate(node.values):
            if idx:
                self.write_c(' %s ' % BOOLOP_SYMBOLS[type(node.op)])
            self.visit(value)
        self.write_c(')')

    def visit_Compare(self, node):
        # Python chains comparisons (a < b < c means a < b and b < c), but C
        # would evaluate (a < b) < c, so expand the chain into a conjunction.
        # The middle operands are therefore evaluated twice in the C code.
        self.write_c('(')
        left = node.left
        for idx, (op, right) in enumerate(zip(node.ops, node.comparators)):
            if idx:
                self.write_c(' && ')
            self.visit(left)
            self.write_c(' %s ' % CMPOP_SYMBOLS[type(op)])
            self.visit(right)
            left = right
        self.write_c(')')

    def visit_UnaryOp(self, node):
        self.write_c('(')
        op = UNARYOP_SYMBOLS[type(node.op)]
        self.write_c(op)
        if op == 'not':
            self.write_c(' ')
        self.visit(node.operand)
        self.write_c(')')

    def visit_Subscript(self, node):
        if node.value.id not in self.c_constants:
            if node.value.id not in self.c_pointers:
                self.c_pointers.append(node.value.id)
        self.in_subref = True
        self.visit(node.value)
        self.in_subref = False
        self.write_c('[')
        self.in_subscript = True
        self.visit(node.slice)
        self.in_subscript = False
        self.write_c(']')

    def visit_Slice(self, node):
        if node.lower is not None:
            self.visit(node.lower)
        self.write_python(':')
        if node.upper is not None:
            self.visit(node.upper)
        if node.step is not None:
            self.write_python(':')
            if not (isinstance(node.step, ast.Name) and node.step.id == 'None'):
                self.visit(node.step)

    def visit_ExtSlice(self, node):
        for idx, item in enumerate(node.dims):
            if idx:
                self.write_python(', ')
            self.visit(item)

    def visit_Yield(self, node):
        self.unsupported(node)

        self.write_python('yield ')
        self.visit(node.value)

    def visit_Lambda(self, node):
        self.unsupported(node)

        self.write_python('lambda ')
        self.visit(node.args)
        self.write_python(': ')
        self.visit(node.body)

    def visit_Ellipsis(self, node):
        self.unsupported(node)

        self.write_python('Ellipsis')

    def generator_visit(left, right):
        def visit(self, node):
            self.write_python(left)
            self.write_c(left)
            self.visit(node.elt)
            for comprehension in node.generators:
                self.visit(comprehension)
            self.write_c(right)
            #self.write_python(right)
        return visit

    visit_ListComp = generator_visit('[', ']')
    visit_GeneratorExp = generator_visit('(', ')')
    visit_SetComp = generator_visit('{', '}')
    del generator_visit

    def visit_DictComp(self, node):
        self.unsupported(node)

        self.write_python('{')
        self.visit(node.key)
        self.write_python(': ')
        self.visit(node.value)
        for comprehension in node.generators:
            self.visit(comprehension)
        self.write_python('}')

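    # Conditional expressions map onto C's ternary operator, for example
    #     a if c else b  ->  ((c)?(a):(b))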
    def visit_IfExp(self, node):
        self.write_c('((')
        self.visit(node.test)
        self.write_c(')?(')
        self.visit(node.body)
        self.write_c('):(')
        self.visit(node.orelse)
        self.write_c('))')

    def visit_Starred(self, node):
        self.write_c('*')
        self.visit(node.value)

    def visit_Repr(self, node):
        # CRUFT: python 2 only (backquote repr was removed in python 3)
        self.write_c('`')
        self.visit(node.value)
        self.write_c('`')

    # Helper Nodes

    def visit_alias(self, node):
        self.unsupported(node)

        self.write_python(node.name)
        if node.asname is not None:
            self.write_python(' as ' + node.asname)

    def visit_comprehension(self, node):
        self.unsupported(node)

        self.write_c(' for ')
        self.visit(node.target)
        self.write_c(' in ')
        #self.write_python(' in ')
        self.visit(node.iter)
        if node.ifs:
            for if_ in node.ifs:
                self.write_python(' if ')
                self.visit(if_)

    def visit_arguments(self, node):
        self.signature(node)

    def unsupported(self, node, message=None):
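        """
        Raise a SyntaxError describing the unsupported *node*.

        The message gives the file name, line number and enclosing function
        when they are known, e.g. for a hypothetical model.py::

            [model.py(12), function Iq] Yield syntax not supported
        """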
        # Use the line number of the node's value or iterator when it has
        # one, then the node's own line number, and 0 as a last resort.
        if hasattr(getattr(node, "value", None), "lineno"):
            lineno = node.value.lineno
        elif hasattr(getattr(node, "iter", None), "lineno"):
            lineno = node.iter.lineno
        elif hasattr(node, "lineno"):
            lineno = node.lineno
        else:
            lineno = 0

        lineno += self.lineno_offset
        if self.fname:
            location = "%s(%d)" % (self.fname, lineno)
        else:
            location = "%d" % lineno
        if self.current_function:
            location += ", function %s" % self.current_function
        if message is None:
            message = node.__class__.__name__ + " syntax not supported"
        raise SyntaxError("[%s] %s" % (location, message))

def print_function(f=None):
    """
    Print the source of function *f*, regenerated by :func:`to_source`
    from its parsed syntax tree.
    """
    # Include some comments to see if they get printed
    import ast
    import inspect
    if f is not None:
        tree = ast.parse(inspect.getsource(f))
        tree_source = to_source(tree)
        print(tree_source)

def define_constant(name, value, block_size=1):
    # type: (str, any, int) -> str
    """
    Convert a python constant into a C constant of the same name.

    Returns the C declaration of the constant as a string, possibly containing
    line feeds.  The string will not be indented.

    Supports int, double and sequences of double.
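
    For example, ``define_constant("C", [1.5, 2], block_size=4)`` pads the
    array with zeros to four elements and returns::

        constant double C[] = {
           1.5, 2, 0, 0
        };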
    """
    const = "constant "  # OpenCL needs globals to be constant
    if isinstance(value, int):
        parts = [const + "int ", name, " = ", "%d"%value, ";\n"]
    elif isinstance(value, float):
        parts = [const + "double ", name, " = ", "%.15g"%value, ";\n"]
    else:
        try:
            len(value)
        except TypeError:
            raise TypeError("constant %s must be int, float or [float, ...]"%name)
        # Pad constant arrays with zeros to a multiple of block_size.  Using
        # block_size=4 is a precaution: it may not be necessary, but some
        # OpenCL targets broke when the number of parameters in the parameter
        # table was not a multiple of 4.
        if len(value)%block_size != 0:
            value = list(value) + [0.]*(block_size - len(value)%block_size)
        elements = ["%.15g"%v for v in value]
        parts = [const + "double ", name, "[]", " = ",
                 "{\n   ", ", ".join(elements), "\n};\n"]

    return "".join(parts)


# Modified from the following:
#
#    http://code.activestate.com/recipes/578272-topological-sort/
#    Copyright (C) 2012 Sam Denton
#    License: MIT
def ordered_dag(dag):
    # type: (Dict[T, Set[T]]) -> Iterator[T]
    """
    Given a dag defined by a dictionary of {k1: [k2, ...]} yield keys
    in order such that every key occurs after the keys it depends upon.

    This is an iterator, not a sequence.  To reverse it use::

        reversed(tuple(ordered_dag(dag)))

    Raises ValueError if there are any cycles.

    Keys are arbitrary hashable values.
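
    For example::

        >>> list(ordered_dag({'a': {'b', 'c'}, 'b': {'c'}}))
        ['c', 'b', 'a']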
    """
    # Local import to make the function stand-alone, and easier to borrow
    from functools import reduce

    # Copy the dag, converting each collection of dependencies to a set so
    # that lists (as in the docstring) work with the set operations below.
    dag = {node: set(links) for node, links in dag.items()}

    # make leaves depend on the empty set
    leaves = reduce(set.union, dag.values(), set()) - set(dag.keys())
    dag.update({node: set() for node in leaves})
    while True:
        leaves = set(node for node, links in dag.items() if not links)
        if not leaves:
            break
        for node in leaves:
            yield node
        dag = {node: (links-leaves)
               for node, links in dag.items() if node not in leaves}
    if dag:
        raise ValueError("Cyclic dependencies exist amongst these items:\n%s"
                         % ", ".join(str(node) for node in dag.keys()))

import re
PRINT_ARGS = re.compile(r'print[(]"(?P<template>[^"]*)" *% *[(](?P<args>[^\n]*)[)] *[)] *\n')
SUBST_ARGS = r'printf("\g<template>\\n", \g<args>)\n'
PRINT_STR = re.compile(r'print[(]"(?P<template>[^"]*)" *[)] *\n')
# "\\n" puts a C escape (not a raw newline) in the string; the trailing "\n"
# restores the line break consumed by the pattern.
SUBST_STR = r'printf("\g<template>\\n")\n'
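# For example, PRINT_ARGS and SUBST_ARGS rewrite the source line
#     print("a=%g b=%g" % (a, b))
# as
#     printf("a=%g b=%g\n", a, b)
# The format arguments must be enclosed in parentheses for the pattern to
# match; other print calls are left unchanged.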
def translate(functions, constants=None):
    # type: (Sequence[Tuple[str, str, int]], Dict[str, Any]) -> Tuple[List[str], List[str]]
    """
    Convert a list of functions to a list of C code strings.

    Returns the list of C code snippets, with each translated function
    preceded by a ``#line`` directive, and a list of warnings generated by
    the translator.

    A function is given by the tuple (source, filename, line number).

    Global constants are given in a dictionary of {name: value}.  The
    constants are used for name space resolution and type inferencing.
    Constants are not translated by this code. Instead, call
    :func:`define_constant` with name and value, and maybe block_size
    if arrays need to be padded to the next block boundary.

    Function prototypes are not generated. Use :func:`ordered_dag` to sort
    the functions so that each one follows the functions it calls, then call
    translate. [Maybe a future revision will return the function prototypes
    so that a suitable "*.h" file can be generated.]
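
    Example, mirroring :func:`main`, where *source* holds the python code
    of a (hypothetical) file model.py::

        snippets, warnings = translate([(source, "model.py", 1)])
        c_code = "".join(snippets)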
    """
    snippets = []
    warnings = []
    for source, fname, lineno in functions:
        line_directive = '#line %d "%s"\n'%(lineno, fname.replace('\\', '\\\\'))
        snippets.append(line_directive)
        # Replace simple print function calls with printf statements
        source = PRINT_ARGS.sub(SUBST_ARGS, source)
        source = PRINT_STR.sub(SUBST_STR, source)
        tree = ast.parse(source)
        generator = SourceGenerator(constants=constants, fname=fname, lineno=lineno)
        generator.visit(tree)
        c_code = "".join(generator.c_proc)
        snippets.append(c_code)
        warnings.extend(generator.warnings)
    return snippets, warnings


# "#line N" sets the number of the line that follows the directive, so N must
# be the number of the "#include <stdio.h>" line, three lines below this one.
C_HEADER_LINENO = getframeinfo(currentframe()).lineno + 3
C_HEADER = """
#line %d "%s"
#include <stdio.h>
#include <stdbool.h>
#include <math.h>
#define constant const
double square(double x) { return x*x; }
double cube(double x) { return x*x*x; }
double polyval(constant double *coef, double x, int N)
{
    // Horner's rule over coef[0], ..., coef[N], highest power first
    int i = 0;
    double ans = coef[0];

    while (i < N) {
        ans = ans * x + coef[++i];
    }

    return ans;
}
"""

USAGE = """\
Usage: python py2c.py <infile> [<outfile>]

If <outfile> is omitted, the output is written to <infile> with its
extension replaced by '.c'.
"""

def main():
    import os
    #print("Parsing...using Python" + sys.version)
    if len(sys.argv) == 1:
        print(USAGE)
        return

    fname_in = sys.argv[1]
    if len(sys.argv) == 2:
        fname_base = os.path.splitext(fname_in)[0]
        fname_out = str(fname_base) + '.c'
    else:
        fname_out = sys.argv[2]

    with open(fname_in, "r") as python_file:
        code = python_file.read()
    # Replace the attributes gauss.n, gauss.z and gauss.w with the C constants
    # GAUSS_N, GAUSS_Z and GAUSS_W, and turn the __main__ block into main().
    name = "gauss"
    code = (code
            .replace(name+'.n', 'GAUSS_N')
            .replace(name+'.z', 'GAUSS_Z')
            .replace(name+'.w', 'GAUSS_W')
            .replace('if __name__ == "__main__"', "def main()")
           )

    translation, warnings = translate([(code, fname_in, 1)])
    c_code = "".join(translation)
    c_code = c_code.replace("double main()", "int main(int argc, char *argv[])")

    with open(fname_out, "w") as file_out:
        file_out.write(C_HEADER%(C_HEADER_LINENO, __file__.replace('\\', '\\\\')))
        file_out.write(c_code)

    if warnings:
        print("\n".join(warnings))
    #print("...Done")

if __name__ == "__main__":
    main()