source: sasmodels/sasmodels/py2c.py @ 6f91c91

Last change on this file since 6f91c91 was 6f91c91, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

track line number for c header

  • Property mode set to 100644
File size: 44.8 KB
RevLine 
[4c87de0]1r"""
2py2c
3~~~~
4
5Convert simple numeric python code into C code.
6
7This code is intended to translate direct algorithms for scientific code
8(mostly if statements and for loops operating on double precision values)
9into C code. Unlike projects like numba, cython, pypy and nuitka, the
10:func:`translate` function returns the corresponding C which can then be
11compiled with tinycc or sent to the GPU using CUDA or OpenCL.
12
13There is special handling certain constructs, such as *for i in range* and
14small integer powers.
15
16**TODO: make a nice list of supported constructs***
17
18Imports are not supported, but they are at least ignored so that properly
19constructed code can be run via python or translated to C without change.
20
21Most other python constructs are **not** supported:
22* classes
23* builtin types (dict, set, list)
24* exceptions
25* with context
26* del
27* yield
28* async
29* list slicing
30* multiple return values
31* "is/is not", "in/not in" conditionals
32
33There is limited support for list and list comprehensions, so long as they
34can be represented by a fixed array whose size is known at compile time, and
35they are small enough to be stored on the stack.
36
37Variables definition in C
38-------------------------
39Defining variables within the translate function is a bit of a guess work,
40using following rules:
41*   By default, a variable is a 'double'.
42*   Variable in a for loop is an int.
43*   Variable that is references with brackets is an array of doubles. The
44    variable within the brackets is integer. For example, in the
45    reference 'var1[var2]', var1 is a double array, and var2 is an integer.
46*   Assignment to an argument makes that argument an array, and the index
47    in that assignment is 0.
48    For example, the following python code::
49        def func(arg1, arg2):
50            arg2 = 17.
51    is translated to the following C code::
52        double func(double arg1)
53        {
54            arg2[0] = 17.0;
55        }
56    For example, the following python code is translated to the
57    following C code::
58
59        def func(arg1, arg2):          double func(double arg1) {
60            arg2 = 17.                      arg2[0] = 17.0;
61                                        }
62*   All functions are defined as double, even if there is no
63    return statement.
64
65Debugging
66---------
67
68*print* is partially supported using a simple regular expression. This
69requires a stylized form. Be sure to use print as a function instead of
70the print statement. If you are including substition variables, use the
71% string substitution style. Include parentheses around the substitution
72tuple, even if there is only one item; do not include the final comma even
73if it is a single item (yes, it won't be a tuple, but it makes the regexp
74much simpler). Keep the item on a single line. Here are three forms that work::
75
76    print("x") => printf("x\n");
77    print("x %g"%(a)) => printf("x %g\n", a);
78    print("x %g %g %g"%(a, b, c)) => printf("x %g %g %g\n", a, b, c);
79
80You can generate *main* using the *if __name__ == "__main__":* construct.
81This does a simple substitution with "def main():" before translation and
82a substitution with "int main(int argc, double *argv[])" after translation.
83The result is that the content of the *if* block becomes the content of *main*.
84Along with the print statement, you can run and test a translation standalone
85using::
86
87    python py2c.py source.py
88    cc source.c
89    ./a.out
90
91Known issues
92------------
93The following constructs may cause problems:
94
95* implicit arrays: possible namespace collision for variable "vec#"
96* swap fails: "x,y = y,x" will set x==y
97* top-level statements: code outside a function body causes errors
98* line number skew: each statement should be tagged with its own #line
99  to avoid skew as comments are skipped and loop bodies are wrapped with
100  braces, etc.
101
102References
103----------
104
105Based on a variant of codegen.py:
106
107    https://github.com/andreif/codegen
[c01ed3e]108    :copyright: Copyright 2008 by Armin Ronacher.
109    :license: BSD.
110"""
[4c87de0]111
[6a37819]112# Update Notes
113# ============
[d5014e4]114# 2017-11-22, OE: Each 'visit_*' method is to build a C statement string. It
115#                 shold insert 4 blanks per indentation level. The 'body'
116#                 method will combine all the strings, by adding the
117#                 'current_statement' to the c_proc string list
118# 2017-11-22, OE: variables, argument definition implemented.  Note: An
119#                 argument is considered an array if it is the target of an
120#                 assignment. In that case it is translated to <var>[0]
121# 2017-11-27, OE: 'pow' basicly working
122# 2017-12-07, OE: Multiple assignment: a1,a2,...,an=b1,b2,...bn implemented
123# 2017-12-07, OE: Power function, including special cases of
[6a37819]124#                 square(x)(pow(x,2)) and cube(x)(pow(x,3)), implemented in
125#                 translate_power, called from visit_BinOp
[d5014e4]126# 2017-12-07, OE: Translation of integer division, '\\' in python, implemented
[6a37819]127#                 in translate_integer_divide, called from visit_BinOp
[d5014e4]128# 2017-12-07, OE: C variable definition handled in 'define_c_vars'
[6a37819]129#               : Python integer division, '//', translated to C in
130#                 'translate_integer_divide'
[d5014e4]131# 2017-12-15, OE: Precedence maintained by writing opening and closing
[6a37819]132#                 parenthesesm '(',')', in procedure 'visit_BinOp'.
[d5014e4]133# 2017-12-18, OE: Added call to 'add_current_line()' at the beginning
[6a37819]134#                 of visit_Return
135# 2018-01-03, PK: Update interface for use in sasmodels
136# 2018-01-03, PK: support "expr if cond else expr" syntax
137# 2018-01-03, PK: x//y => (int)((x)/(y)) and x/y => ((double)(x)/(double)(y))
138# 2018-01-03, PK: True/False => true/false
139# 2018-01-03, PK: f(x) was introducing an extra semicolon
140# 2018-01-03, PK: simplistic print function, for debugging
141# 2018-01-03, PK: while expr: ... => while (expr) { ... }
[d5014e4]142# 2018-01-04, OE: Fixed bug in 'visit_If': visiting node.orelse in case else exists.
[6a37819]143
144from __future__ import print_function
[4c87de0]145
[71779b2]146import sys
[6a37819]147import ast
[71779b2]148from ast import NodeVisitor
[6f91c91]149from inspect import currentframe, getframeinfo
150
[5dd7cfb]151try: # for debugging, astor lets us print out the node as python
152    import astor
153except ImportError:
154    pass
[71779b2]155
156BINOP_SYMBOLS = {}
157BINOP_SYMBOLS[ast.Add] = '+'
158BINOP_SYMBOLS[ast.Sub] = '-'
159BINOP_SYMBOLS[ast.Mult] = '*'
160BINOP_SYMBOLS[ast.Div] = '/'
161BINOP_SYMBOLS[ast.Mod] = '%'
162BINOP_SYMBOLS[ast.Pow] = '**'
163BINOP_SYMBOLS[ast.LShift] = '<<'
164BINOP_SYMBOLS[ast.RShift] = '>>'
165BINOP_SYMBOLS[ast.BitOr] = '|'
166BINOP_SYMBOLS[ast.BitXor] = '^'
167BINOP_SYMBOLS[ast.BitAnd] = '&'
168BINOP_SYMBOLS[ast.FloorDiv] = '//'
169
170BOOLOP_SYMBOLS = {}
171BOOLOP_SYMBOLS[ast.And] = '&&'
[7b1dcf9]172BOOLOP_SYMBOLS[ast.Or] = '||'
[71779b2]173
174CMPOP_SYMBOLS = {}
[7b1dcf9]175CMPOP_SYMBOLS[ast.Eq] = '=='
[71779b2]176CMPOP_SYMBOLS[ast.NotEq] = '!='
177CMPOP_SYMBOLS[ast.Lt] = '<'
178CMPOP_SYMBOLS[ast.LtE] = '<='
179CMPOP_SYMBOLS[ast.Gt] = '>'
180CMPOP_SYMBOLS[ast.GtE] = '>='
181CMPOP_SYMBOLS[ast.Is] = 'is'
182CMPOP_SYMBOLS[ast.IsNot] = 'is not'
183CMPOP_SYMBOLS[ast.In] = 'in'
184CMPOP_SYMBOLS[ast.NotIn] = 'not in'
185
186UNARYOP_SYMBOLS = {}
187UNARYOP_SYMBOLS[ast.Invert] = '~'
188UNARYOP_SYMBOLS[ast.Not] = 'not'
189UNARYOP_SYMBOLS[ast.UAdd] = '+'
190UNARYOP_SYMBOLS[ast.USub] = '-'
191
192
[d5014e4]193# TODO: should not allow eval of arbitrary python
[3f9db6e]194def isevaluable(s):
195    try:
196        eval(s)
197        return True
[c01ed3e]198    except Exception:
[3f9db6e]199        return False
[fa74acf]200
[5dd7cfb]201def render_expression(tree):
202    generator = SourceGenerator()
203    generator.visit(tree)
204    c_code = "".join(generator.current_statement)
205    return c_code
206
[71779b2]207class SourceGenerator(NodeVisitor):
208    """This visitor is able to transform a well formed syntax tree into python
209    sourcecode.  For more details have a look at the docstring of the
210    `node_to_source` function.
211    """
212
[5dd7cfb]213    def __init__(self, indent_with="    ", constants=None, fname=None, lineno=0):
[71779b2]214        self.indent_with = indent_with
215        self.indentation = 0
[c01ed3e]216
[7b1dcf9]217        # for C
[c01ed3e]218        self.c_proc = []
[71779b2]219        self.signature_line = 0
220        self.arguments = []
[c01ed3e]221        self.current_function = ""
222        self.fname = fname
223        self.lineno_offset = lineno
[71779b2]224        self.warnings = []
225        self.current_statement = ""
[c01ed3e]226        # TODO: use set rather than list for c_vars, ...
227        self.c_vars = []
228        self.c_int_vars = []
229        self.c_pointers = []
230        self.c_dcl_pointers = []
231        self.c_functions = []
232        self.c_vectors = []
233        self.c_constants = constants if constants is not None else {}
[4c87de0]234        self.in_expr = False
[c01ed3e]235        self.in_subref = False
236        self.in_subscript = False
237        self.tuples = []
[71779b2]238        self.required_functions = []
239        self.visited_args = False
240
[c01ed3e]241    def write_c(self, statement):
242        # TODO: build up as a list rather than adding to string
243        self.current_statement += statement
[fa74acf]244
[5dd7cfb]245    def write_python(self, x):
246        raise NotImplementedError("shouldn't be trying to write pythnon")
247
[c01ed3e]248    def add_c_line(self, line):
249        indentation = self.indent_with * self.indentation
250        self.c_proc.append("".join((indentation, line, "\n")))
[71779b2]251
[fa74acf]252    def add_current_line(self):
[7b1dcf9]253        if self.current_statement:
[fa74acf]254            self.add_c_line(self.current_statement)
[71779b2]255            self.current_statement = ''
256
[c01ed3e]257    def add_unique_var(self, new_var):
258        if new_var not in self.c_vars:
259            self.c_vars.append(str(new_var))
[71779b2]260
[c01ed3e]261    def write_sincos(self, node):
[71779b2]262        angle = str(node.args[0].id)
[fa74acf]263        self.write_c(node.args[1].id + " = sin(" + angle + ");")
[71779b2]264        self.add_current_line()
[fa74acf]265        self.write_c(node.args[2].id + " = cos(" + angle + ");")
[71779b2]266        self.add_current_line()
267        for arg in node.args:
[c01ed3e]268            self.add_unique_var(arg.id)
[71779b2]269
[5dd7cfb]270    def track_lineno(self, node):
271        #print("newline", node, [s for s in dir(node) if not s.startswith('_')])
272        if hasattr(node, 'lineno'):
273            line = '#line %d "%s"\n' % (node.lineno+self.lineno_offset-1, self.fname)
274            self.c_proc.append(line)
[71779b2]275
276    def body(self, statements):
[7b1dcf9]277        if self.current_statement:
[fa74acf]278            self.add_current_line()
[71779b2]279        self.new_line = True
280        self.indentation += 1
281        for stmt in statements:
[7b1dcf9]282            #if hasattr(stmt, 'targets') and hasattr(stmt.targets[0], 'id'):
283            #    target_name = stmt.targets[0].id # target name needed for debug only
[71779b2]284            self.visit(stmt)
[fa74acf]285        self.add_current_line() # just for breaking point. to be deleted.
[71779b2]286        self.indentation -= 1
287
288    def body_or_else(self, node):
289        self.body(node.body)
290        if node.orelse:
[15be191]291            self.unsupported(node, "for...else/while...else not supported")
292
[5dd7cfb]293            self.track_lineno(node)
[71779b2]294            self.write_c('else:')
295            self.body(node.orelse)
296
297    def signature(self, node):
298        want_comma = []
299        def write_comma():
300            if want_comma:
301                self.write_c(', ')
302            else:
303                want_comma.append(True)
[7b1dcf9]304
305        # for C
[71779b2]306        for arg in node.args:
[1ddb794]307            # CRUFT: 2.7 uses arg.id, 3.x uses arg.arg
308            try:
309                arg_name = arg.arg
310            except AttributeError:
311                arg_name = arg.id
312            self.arguments.append(arg_name)
[71779b2]313
[fa74acf]314        padding = [None] *(len(node.args) - len(node.defaults))
[71779b2]315        for arg, default in zip(node.args, padding + node.defaults):
316            if default is not None:
[1ddb794]317                # CRUFT: 2.7 uses arg.id, 3.x uses arg.arg
318                try:
319                    arg_name = arg.arg
320                except AttributeError:
321                    arg_name = arg.id
[d5014e4]322                w_str = ("C does not support default parameters: %s=%s"
[15be191]323                         % (arg_name, str(default.n)))
[fa74acf]324                self.warnings.append(w_str)
[71779b2]325
326    def decorators(self, node):
[5dd7cfb]327        if node.decorator_list:
328            self.unsupported(node.decorator_list[0])
[71779b2]329        for decorator in node.decorator_list:
[5dd7cfb]330            self.trac_lineno(decorator)
[71779b2]331            self.write_python('@')
332            self.visit(decorator)
333
334    # Statements
335
336    def visit_Assert(self, node):
[5dd7cfb]337        self.unsupported(node)
338
339        self.track_lineno(node)
[71779b2]340        self.write_c('assert ')
341        self.visit(node.test)
342        if node.msg is not None:
[fa74acf]343            self.write_python(', ')
344            self.visit(node.msg)
[71779b2]345
[c01ed3e]346    def define_c_vars(self, target):
[7b1dcf9]347        if hasattr(target, 'id'):
348        # a variable is considered an array if it apears in the agrument list
349        # and being assigned to. For example, the variable p in the following
350        # sniplet is a pointer, while q is not
351        # def somefunc(p, q):
352        #  p = q + 1
353        #  return
354        #
[c01ed3e]355            if target.id not in self.c_vars:
[7b1dcf9]356                if target.id in self.arguments:
[fa74acf]357                    idx = self.arguments.index(target.id)
[71779b2]358                    new_target = self.arguments[idx] + "[0]"
[c01ed3e]359                    if new_target not in self.c_pointers:
[71779b2]360                        target.id = new_target
[c01ed3e]361                        self.c_pointers.append(self.arguments[idx])
[71779b2]362                else:
[c01ed3e]363                    self.c_vars.append(target.id)
[71779b2]364
[fa74acf]365    def add_semi_colon(self):
[4c87de0]366        #semi_pos = self.current_statement.find(';')
367        #if semi_pos >= 0:
368        #    self.current_statement = self.current_statement.replace(';', '')
[98a4f14]369        self.write_c(';')
[937afef]370
[71779b2]371    def visit_Assign(self, node):
[fa74acf]372        self.add_current_line()
[5dd7cfb]373        self.track_lineno(node)
[4c87de0]374        self.in_expr = True
[71779b2]375        for idx, target in enumerate(node.targets): # multi assign, as in 'a = b = c = 7'
376            if idx:
377                self.write_c(' = ')
[c01ed3e]378            self.define_c_vars(target)
[71779b2]379            self.visit(target)
[c01ed3e]380        # Capture assigned tuple names, if any
381        targets = self.tuples[:]
382        del self.tuples[:]
[71779b2]383        self.write_c(' = ')
384        self.visited_args = False
385        self.visit(node.value)
[fa74acf]386        self.add_semi_colon()
387        self.add_current_line()
[c01ed3e]388        # Assign tuples to tuples, if any
389        # TODO: doesn't handle swap:  a,b = b,a
390        for target, item in zip(targets, self.tuples):
391            self.visit(target)
[71779b2]392            self.write_c(' = ')
393            self.visit(item)
[fa74acf]394            self.add_semi_colon()
395            self.add_current_line()
[5dd7cfb]396        #if self.is_sequence and not self.visited_args:
397        #    for target in node.targets:
398        #        if hasattr(target, 'id'):
399        #            if target.id in self.c_vars and target.id not in self.c_dcl_pointers:
400        #                if target.id not in self.c_dcl_pointers:
401        #                    self.c_dcl_pointers.append(target.id)
402        #                    if target.id in self.c_vars:
403        #                        self.c_vars.remove(target.id)
[71779b2]404        self.current_statement = ''
[4c87de0]405        self.in_expr = False
[71779b2]406
407    def visit_AugAssign(self, node):
[c01ed3e]408        if node.target.id not in self.c_vars:
[7b1dcf9]409            if node.target.id not in self.arguments:
[c01ed3e]410                self.c_vars.append(node.target.id)
[4c87de0]411        self.in_expr = True
[71779b2]412        self.visit(node.target)
413        self.write_c(' ' + BINOP_SYMBOLS[type(node.op)] + '= ')
414        self.visit(node.value)
[fa74acf]415        self.add_semi_colon()
[4c87de0]416        self.in_expr = False
[fa74acf]417        self.add_current_line()
[71779b2]418
419    def visit_ImportFrom(self, node):
[c01ed3e]420        return  # import ignored
[5dd7cfb]421        self.track_lineno(node)
[fa74acf]422        self.write_python('from %s%s import ' %('.' * node.level, node.module))
[71779b2]423        for idx, item in enumerate(node.names):
424            if idx:
425                self.write_python(', ')
426            self.write_python(item)
427
428    def visit_Import(self, node):
[c01ed3e]429        return  # import ignored
[5dd7cfb]430        self.track_lineno(node)
[71779b2]431        for item in node.names:
432            self.write_python('import ')
433            self.visit(item)
434
435    def visit_Expr(self, node):
[4c87de0]436        #self.in_expr = True
[5dd7cfb]437        #self.track_lineno(node)
[71779b2]438        self.generic_visit(node)
[4c87de0]439        #self.in_expr = False
[71779b2]440
[c01ed3e]441    def write_c_pointers(self, start_var):
442        if self.c_dcl_pointers:
[7b1dcf9]443            var_list = []
[c01ed3e]444            for c_ptr in self.c_dcl_pointers:
[7b1dcf9]445                if c_ptr not in self.arguments:
446                    var_list.append("*" + c_ptr)
[c01ed3e]447                if c_ptr in self.c_vars:
448                    self.c_vars.remove(c_ptr)
[7b1dcf9]449            if var_list:
450                c_dcl = "    double " + ", ".join(var_list) + ";\n"
451                self.c_proc.insert(start_var, c_dcl)
[71c5f4d]452                start_var += 1
453        return start_var
454
[c01ed3e]455    def insert_c_vars(self, start_var):
456        have_decls = False
457        start_var = self.write_c_pointers(start_var)
458        if self.c_int_vars:
459            for var in self.c_int_vars:
460                if var in self.c_vars:
461                    self.c_vars.remove(var)
462            decls = ", ".join(self.c_int_vars)
463            self.c_proc.insert(start_var, "    int " + decls + ";\n")
464            have_decls = True
[71779b2]465            start_var += 1
[fa74acf]466
[c01ed3e]467        if self.c_vars:
468            decls = ", ".join(self.c_vars)
469            self.c_proc.insert(start_var, "    double " + decls + ";\n")
470            have_decls = True
[937afef]471            start_var += 1
[7b1dcf9]472
[c01ed3e]473        if self.c_vectors:
474            for vec_number, vec_value  in enumerate(self.c_vectors):
475                name = "vec" + str(vec_number + 1)
476                decl = "    double " + name + "[] = {" + vec_value + "};"
477                self.c_proc.insert(start_var, decl + "\n")
[71779b2]478                start_var += 1
[7b1dcf9]479
[c01ed3e]480        del self.c_vars[:]
481        del self.c_int_vars[:]
482        del self.c_vectors[:]
483        del self.c_pointers[:]
484        del self.c_dcl_pointers[:]
485        if have_decls:
[fa74acf]486            self.c_proc.insert(start_var, "\n")
[71779b2]487
[c01ed3e]488    def insert_signature(self):
[7b1dcf9]489        arg_decls = []
490        for arg in self.arguments:
491            decl = "double " + arg
[c01ed3e]492            if arg in self.c_pointers:
[7b1dcf9]493                decl += "[]"
494            arg_decls.append(decl)
495        args_str = ", ".join(arg_decls)
[c01ed3e]496        method_sig = 'double ' + self.current_function + '(' + args_str + ")"
[7b1dcf9]497        if self.signature_line >= 0:
[c01ed3e]498            self.c_proc.insert(self.signature_line, method_sig)
[71779b2]499
500    def visit_FunctionDef(self, node):
[c01ed3e]501        if self.current_function:
502            self.unsupported(node, "function within a function")
503        self.current_function = node.name
504
[d5014e4]505        # remember the location of the next warning that will be inserted
506        # so that we can stuff the function name ahead of the warning list
507        # if any warnings are generated by the function.
[4339764]508        warning_index = len(self.warnings)
[d5014e4]509
[71779b2]510        self.decorators(node)
[5dd7cfb]511        self.track_lineno(node)
[71779b2]512        self.arguments = []
513        self.visit(node.args)
[7b1dcf9]514        # for C
[71779b2]515        self.signature_line = len(self.c_proc)
516        self.add_c_line("\n{")
517        start_vars = len(self.c_proc) + 1
518        self.body(node.body)
519        self.add_c_line("}\n")
[c01ed3e]520        self.insert_signature()
521        self.insert_c_vars(start_vars)
[15be191]522        del self.c_pointers[:]
[c01ed3e]523        self.current_function = ""
[d5014e4]524        if warning_index != len(self.warnings):
[4339764]525            self.warnings.insert(warning_index, "Warning in function '" + node.name + "':")
[71779b2]526
527    def visit_ClassDef(self, node):
[5dd7cfb]528        self.unsupported(node)
529
[71779b2]530        have_args = []
531        def paren_or_comma():
532            if have_args:
533                self.write_python(', ')
534            else:
535                have_args.append(True)
536                self.write_python('(')
537
538        self.decorators(node)
[5dd7cfb]539        self.track_lineno(node)
[71779b2]540        self.write_python('class %s' % node.name)
541        for base in node.bases:
542            paren_or_comma()
543            self.visit(base)
[7b1dcf9]544        # CRUFT: python 2.6 does not have "keywords" attribute
[71779b2]545        if hasattr(node, 'keywords'):
546            for keyword in node.keywords:
547                paren_or_comma()
548                self.write_python(keyword.arg + '=')
549                self.visit(keyword.value)
550            if node.starargs is not None:
551                paren_or_comma()
552                self.write_python('*')
553                self.visit(node.starargs)
554            if node.kwargs is not None:
555                paren_or_comma()
556                self.write_python('**')
557                self.visit(node.kwargs)
558        self.write_python(have_args and '):' or ':')
559        self.body(node.body)
560
561    def visit_If(self, node):
[4c87de0]562
[5dd7cfb]563        self.track_lineno(node)
[71779b2]564        self.write_c('if ')
[4c87de0]565        self.in_expr = True
[71779b2]566        self.visit(node.test)
[4c87de0]567        self.in_expr = False
[71779b2]568        self.write_c(' {')
569        self.body(node.body)
570        self.add_c_line('}')
571        while True:
572            else_ = node.orelse
573            if len(else_) == 0:
574                break
[7b1dcf9]575            #elif hasattr(else_, 'orelse'):
[71779b2]576            elif len(else_) == 1 and isinstance(else_[0], ast.If):
577                node = else_[0]
[5dd7cfb]578                self.track_lineno(node)
[71779b2]579                self.write_c('else if ')
[4c87de0]580                self.in_expr = True
[71779b2]581                self.visit(node.test)
[4c87de0]582                self.in_expr = False
[71779b2]583                self.write_c(' {')
584                self.body(node.body)
[fa74acf]585                self.add_current_line()
[71779b2]586                self.add_c_line('}')
[7b1dcf9]587                #break
[71779b2]588            else:
[5dd7cfb]589                self.track_lineno(else_)
[71779b2]590                self.write_c('else {')
[d5014e4]591                self.body(else_)
[71779b2]592                self.add_c_line('}')
593                break
594
[c01ed3e]595    def get_for_range(self, node):
[937afef]596        stop = ""
597        start = '0'
598        step = '1'
599        for_args = []
600        temp_statement = self.current_statement
601        self.current_statement = ''
602        for arg in node.iter.args:
603            self.visit(arg)
604            for_args.append(self.current_statement)
[71779b2]605            self.current_statement = ''
[937afef]606        self.current_statement = temp_statement
[7b1dcf9]607        if len(for_args) == 1:
[937afef]608            stop = for_args[0]
[7b1dcf9]609        elif len(for_args) == 2:
[937afef]610            start = for_args[0]
611            stop = for_args[1]
[7b1dcf9]612        elif len(for_args) == 3:
[937afef]613            start = for_args[0]
614            stop = for_args[1]
615            start = for_args[2]
616        else:
617            raise("Ilegal for loop parameters")
[7b1dcf9]618        return start, stop, step
[937afef]619
[d5014e4]620    def add_c_int_var(self, name):
621        if name not in self.c_int_vars:
622            self.c_int_vars.append(name)
[4339764]623
[937afef]624    def visit_For(self, node):
[7b1dcf9]625        # node: for iterator is stored in node.target.
626        # Iterator name is in node.target.id.
[937afef]627        self.add_current_line()
[71779b2]628        fForDone = False
[937afef]629        self.current_statement = ''
[7b1dcf9]630        if hasattr(node.iter, 'func'):
631            if hasattr(node.iter.func, 'id'):
632                if node.iter.func.id == 'range':
[937afef]633                    self.visit(node.target)
634                    iterator = self.current_statement
635                    self.current_statement = ''
[d5014e4]636                    self.add_c_int_var(iterator)
[c01ed3e]637                    start, stop, step = self.get_for_range(node)
[4c87de0]638                    self.write_c("for (" + iterator + "=" + str(start) +
[7b1dcf9]639                                 " ; " + iterator + " < " + str(stop) +
640                                 " ; " + iterator + " += " + str(step) + ") {")
[71779b2]641                    self.body_or_else(node)
[fa74acf]642                    self.write_c("}")
[71779b2]643                    fForDone = True
[7b1dcf9]644        if not fForDone:
[c01ed3e]645            # Generate the statement that is causing the error
[71779b2]646            self.current_statement = ''
647            self.write_c('for ')
648            self.visit(node.target)
649            self.write_c(' in ')
650            self.visit(node.iter)
651            self.write_c(':')
[c01ed3e]652            # report the error
[15be191]653            self.unsupported(node, "unsupported " + self.current_statement)
[71779b2]654
655    def visit_While(self, node):
[5dd7cfb]656        self.track_lineno(node)
[71779b2]657        self.write_c('while ')
658        self.visit(node.test)
[4c87de0]659        self.write_c(' {')
[71779b2]660        self.body_or_else(node)
[4c87de0]661        self.write_c('}')
662        self.add_current_line()
[71779b2]663
664    def visit_With(self, node):
[c01ed3e]665        self.unsupported(node)
[15be191]666
[5dd7cfb]667        self.track_lineno(node)
[71779b2]668        self.write_python('with ')
669        self.visit(node.context_expr)
670        if node.optional_vars is not None:
671            self.write_python(' as ')
672            self.visit(node.optional_vars)
673        self.write_python(':')
674        self.body(node.body)
675
676    def visit_Pass(self, node):
[5dd7cfb]677        #self.track_lineno(node)
[c01ed3e]678        #self.write_python('pass')
[5dd7cfb]679        pass
[71779b2]680
681    def visit_Print(self, node):
[c01ed3e]682        self.unsupported(node)
[15be191]683
[7b1dcf9]684        # CRUFT: python 2.6 only
[5dd7cfb]685        self.track_lineno(node)
[71779b2]686        self.write_c('print ')
687        want_comma = False
688        if node.dest is not None:
689            self.write_c(' >> ')
690            self.visit(node.dest)
691            want_comma = True
692        for value in node.values:
693            if want_comma:
694                self.write_c(', ')
695            self.visit(value)
696            want_comma = True
697        if not node.nl:
698            self.write_c(',')
699
700    def visit_Delete(self, node):
[c01ed3e]701        self.unsupported(node)
[15be191]702
[5dd7cfb]703        self.track_lineno(node)
[71779b2]704        self.write_python('del ')
705        for idx, target in enumerate(node):
706            if idx:
707                self.write_python(', ')
708            self.visit(target)
709
710    def visit_TryExcept(self, node):
[c01ed3e]711        self.unsupported(node)
[15be191]712
[5dd7cfb]713        self.track_linno(node)
[71779b2]714        self.write_python('try:')
715        self.body(node.body)
716        for handler in node.handlers:
717            self.visit(handler)
718
719    def visit_TryFinally(self, node):
[c01ed3e]720        self.unsupported(node)
[15be191]721
[5dd7cfb]722        self.track_lineno(node)
[71779b2]723        self.write_python('try:')
724        self.body(node.body)
[5dd7cfb]725        self.track_lineno(node)
[71779b2]726        self.write_python('finally:')
727        self.body(node.finalbody)
728
729    def visit_Global(self, node):
[c01ed3e]730        self.unsupported(node)
[15be191]731
[5dd7cfb]732        self.track_lineno(node)
[71779b2]733        self.write_python('global ' + ', '.join(node.names))
734
735    def visit_Nonlocal(self, node):
[5dd7cfb]736        self.track_lineno(node)
[71779b2]737        self.write_python('nonlocal ' + ', '.join(node.names))
738
739    def visit_Return(self, node):
[4c72117]740        self.add_current_line()
[5dd7cfb]741        self.track_lineno(node)
[4c87de0]742        self.in_expr = True
[71779b2]743        if node.value is None:
744            self.write_c('return')
745        else:
[d5014e4]746            self.write_c('return ')
[71779b2]747            self.visit(node.value)
[98a4f14]748        self.add_semi_colon()
[4c87de0]749        self.in_expr = False
[fa74acf]750        self.add_c_line(self.current_statement)
[71779b2]751        self.current_statement = ''
752
753    def visit_Break(self, node):
[5dd7cfb]754        self.track_lineno(node)
[71779b2]755        self.write_c('break')
756
757    def visit_Continue(self, node):
[5dd7cfb]758        self.track_lineno(node)
[71779b2]759        self.write_c('continue')
760
761    def visit_Raise(self, node):
[c01ed3e]762        self.unsupported(node)
[15be191]763
[7b1dcf9]764        # CRUFT: Python 2.6 / 3.0 compatibility
[5dd7cfb]765        self.track_lineno(node)
[71779b2]766        self.write_python('raise')
767        if hasattr(node, 'exc') and node.exc is not None:
768            self.write_python(' ')
769            self.visit(node.exc)
770            if node.cause is not None:
771                self.write_python(' from ')
772                self.visit(node.cause)
773        elif hasattr(node, 'type') and node.type is not None:
774            self.visit(node.type)
775            if node.inst is not None:
776                self.write_python(', ')
777                self.visit(node.inst)
778            if node.tback is not None:
779                self.write_python(', ')
780                self.visit(node.tback)
781
782    # Expressions
783
784    def visit_Attribute(self, node):
[c01ed3e]785        self.unsupported(node, "attribute reference a.b not supported")
[15be191]786
[71779b2]787        self.visit(node.value)
788        self.write_python('.' + node.attr)
789
790    def visit_Call(self, node):
791        want_comma = []
792        def write_comma():
793            if want_comma:
794                self.write_c(', ')
795            else:
796                want_comma.append(True)
[7b1dcf9]797        if hasattr(node.func, 'id'):
[c01ed3e]798            if node.func.id not in self.c_functions:
799                self.c_functions.append(node.func.id)
[7b1dcf9]800            if node.func.id == 'abs':
[fa74acf]801                self.write_c("fabs ")
[7b1dcf9]802            elif node.func.id == 'int':
[71779b2]803                self.write_c('(int) ')
[7b1dcf9]804            elif node.func.id == "SINCOS":
[c01ed3e]805                self.write_sincos(node)
[71779b2]806                return
807            else:
808                self.visit(node.func)
809        else:
810            self.visit(node.func)
811        self.write_c('(')
812        for arg in node.args:
813            write_comma()
[fa74acf]814            self.visited_args = True
[71779b2]815            self.visit(arg)
816        for keyword in node.keywords:
817            write_comma()
818            self.write_c(keyword.arg + '=')
819            self.visit(keyword.value)
[fa74acf]820        if hasattr(node, 'starargs'):
[71779b2]821            if node.starargs is not None:
822                write_comma()
823                self.write_c('*')
824                self.visit(node.starargs)
[fa74acf]825        if hasattr(node, 'kwargs'):
[71779b2]826            if node.kwargs is not None:
827                write_comma()
828                self.write_c('**')
829                self.visit(node.kwargs)
[4c87de0]830        self.write_c(')')
831        if not self.in_expr:
832            self.add_semi_colon()
[71779b2]833
[4c87de0]834    TRANSLATE_CONSTANTS = {
[6a37819]835        # python 2 uses normal name references through vist_Name
[4c87de0]836        'True': 'true',
837        'False': 'false',
838        'None': 'NULL',  # "None" will probably fail for other reasons
[6a37819]839        # python 3 uses NameConstant
840        True: 'true',
841        False: 'false',
842        None: 'NULL',  # "None" will probably fail for other reasons
[4c87de0]843        }
[6a37819]844
[71779b2]845    def visit_Name(self, node):
[4c87de0]846        translation = self.TRANSLATE_CONSTANTS.get(node.id, None)
847        if translation:
848            self.write_c(translation)
849            return
[71779b2]850        self.write_c(node.id)
[c01ed3e]851        if node.id in self.c_pointers and not self.in_subref:
[71779b2]852            self.write_c("[0]")
853        name = ""
854        sub = node.id.find("[")
[7b1dcf9]855        if sub > 0:
[71779b2]856            name = node.id[0:sub].strip()
857        else:
858            name = node.id
[c01ed3e]859        # add variable to c_vars if it ins't there yet, not an argument and not a number
860        if (name not in self.c_functions and name not in self.c_vars and
861                name not in self.c_int_vars and name not in self.arguments and
862                name not in self.c_constants and not name.isdigit()):
863            if self.in_subscript:
[d5014e4]864                self.add_c_int_var(node.id)
[937afef]865            else:
[c01ed3e]866                self.c_vars.append(node.id)
[71779b2]867
[6a37819]868    def visit_NameConstant(self, node):
869        translation = self.TRANSLATE_CONSTANTS.get(node.value, None)
870        if translation is not None:
871            self.write_c(translation)
872        else:
873            self.unsupported(node, "don't know how to translate %r"%node.value)
874
[71779b2]875    def visit_Str(self, node):
[4c87de0]876        s = node.s
[5dd7cfb]877        s = s.replace('\\', '\\\\').replace('"', '\\"').replace('\n', '\\n')
[4c87de0]878        self.write_c('"')
879        self.write_c(s)
880        self.write_c('"')
[71779b2]881
882    def visit_Bytes(self, node):
[4c87de0]883        s = node.s
[5dd7cfb]884        s = s.replace('\\', '\\\\').replace('"', '\\"').replace('\n', '\\n')
[4c87de0]885        self.write_c('"')
886        self.write_c(s)
887        self.write_c('"')
[71779b2]888
889    def visit_Num(self, node):
890        self.write_c(repr(node.n))
891
892    def visit_Tuple(self, node):
893        for idx, item in enumerate(node.elts):
894            if idx:
[c01ed3e]895                self.tuples.append(item)
[71779b2]896            else:
897                self.visit(item)
898
[5dd7cfb]899    def visit_List(self, node):
900        #self.unsupported(node)
901        #print("visiting", node)
902        #print(astor.to_source(node))
903        #print(node.elts)
904        exprs = [render_expression(item) for item in node.elts]
905        if exprs:
906            self.c_vectors.append(', '.join(exprs))
907            vec_name = "vec"  + str(len(self.c_vectors))
908            self.write_c(vec_name)
909
910    def visit_Set(self, node):
911        self.unsupported(node)
[71779b2]912
913    def visit_Dict(self, node):
[c01ed3e]914        self.unsupported(node)
[15be191]915
[71779b2]916        self.write_python('{')
917        for idx, (key, value) in enumerate(zip(node.keys, node.values)):
918            if idx:
919                self.write_python(', ')
920            self.visit(key)
921            self.write_python(': ')
922            self.visit(value)
923        self.write_python('}')
924
[fa74acf]925    def get_special_power(self, string):
[3f9db6e]926        function_name = ''
927        is_negative_exp = False
[7b1dcf9]928        if isevaluable(str(self.current_statement)):
[3f9db6e]929            exponent = eval(string)
930            is_negative_exp = exponent < 0
931            abs_exponent = abs(exponent)
[7b1dcf9]932            if abs_exponent == 2:
[3f9db6e]933                function_name = "square"
[7b1dcf9]934            elif abs_exponent == 3:
[3f9db6e]935                function_name = "cube"
[7b1dcf9]936            elif abs_exponent == 0.5:
[3f9db6e]937                function_name = "sqrt"
[7b1dcf9]938            elif abs_exponent == 1.0/3.0:
[3f9db6e]939                function_name = "cbrt"
[7b1dcf9]940        if function_name == '':
[3f9db6e]941            function_name = "pow"
942        return function_name, is_negative_exp
943
[fa74acf]944    def translate_power(self, node):
[7b1dcf9]945        # get exponent by visiting the right hand argument.
[71779b2]946        function_name = "pow"
947        temp_statement = self.current_statement
[7b1dcf9]948        # 'visit' functions write the results to the 'current_statement' class memnber
949        # Here, a temporary variable, 'temp_statement', is used, that enables the
950        # use of the 'visit' function
[71779b2]951        self.current_statement = ''
952        self.visit(node.right)
[fa74acf]953        exponent = self.current_statement.replace(' ', '')
954        function_name, is_negative_exp = self.get_special_power(self.current_statement)
[71779b2]955        self.current_statement = temp_statement
[7b1dcf9]956        if is_negative_exp:
[fa74acf]957            self.write_c("1.0 /(")
958        self.write_c(function_name + "(")
[71779b2]959        self.visit(node.left)
[7b1dcf9]960        if function_name == "pow":
[71779b2]961            self.write_c(", ")
962            self.visit(node.right)
963        self.write_c(")")
[7b1dcf9]964        if is_negative_exp:
[3f9db6e]965            self.write_c(")")
[4c87de0]966        #self.write_c(" ")
[71779b2]967
[fa74acf]968    def translate_integer_divide(self, node):
[4c87de0]969        self.write_c("(int)((")
[71779b2]970        self.visit(node.left)
[4c87de0]971        self.write_c(")/(")
[71779b2]972        self.visit(node.right)
[4c87de0]973        self.write_c("))")
974
975    def translate_float_divide(self, node):
976        self.write_c("((double)(")
977        self.visit(node.left)
978        self.write_c(")/(double)(")
979        self.visit(node.right)
980        self.write_c("))")
[71779b2]981
982    def visit_BinOp(self, node):
[fb5c8c7]983        self.write_c("(")
[7b1dcf9]984        if '%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.Pow]:
[fa74acf]985            self.translate_power(node)
[7b1dcf9]986        elif '%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.FloorDiv]:
[fa74acf]987            self.translate_integer_divide(node)
[4c87de0]988        elif '%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.Div]:
989            self.translate_float_divide(node)
[71779b2]990        else:
991            self.visit(node.left)
992            self.write_c(' %s ' % BINOP_SYMBOLS[type(node.op)])
993            self.visit(node.right)
[fb5c8c7]994        self.write_c(")")
995
[7b1dcf9]996    # for C
[71779b2]997    def visit_BoolOp(self, node):
998        self.write_c('(')
999        for idx, value in enumerate(node.values):
1000            if idx:
1001                self.write_c(' %s ' % BOOLOP_SYMBOLS[type(node.op)])
1002            self.visit(value)
1003        self.write_c(')')
1004
1005    def visit_Compare(self, node):
1006        self.write_c('(')
1007        self.visit(node.left)
1008        for op, right in zip(node.ops, node.comparators):
1009            self.write_c(' %s ' % CMPOP_SYMBOLS[type(op)])
1010            self.visit(right)
1011        self.write_c(')')
1012
1013    def visit_UnaryOp(self, node):
1014        self.write_c('(')
1015        op = UNARYOP_SYMBOLS[type(node.op)]
1016        self.write_c(op)
1017        if op == 'not':
1018            self.write_c(' ')
1019        self.visit(node.operand)
1020        self.write_c(')')
1021
1022    def visit_Subscript(self, node):
[c01ed3e]1023        if node.value.id not in self.c_constants:
1024            if node.value.id not in self.c_pointers:
1025                self.c_pointers.append(node.value.id)
1026        self.in_subref = True
[71779b2]1027        self.visit(node.value)
[c01ed3e]1028        self.in_subref = False
[71779b2]1029        self.write_c('[')
[c01ed3e]1030        self.in_subscript = True
[71779b2]1031        self.visit(node.slice)
[c01ed3e]1032        self.in_subscript = False
[71779b2]1033        self.write_c(']')
1034
1035    def visit_Slice(self, node):
1036        if node.lower is not None:
1037            self.visit(node.lower)
1038        self.write_python(':')
1039        if node.upper is not None:
1040            self.visit(node.upper)
1041        if node.step is not None:
1042            self.write_python(':')
[fa74acf]1043            if not(isinstance(node.step, Name) and node.step.id == 'None'):
[71779b2]1044                self.visit(node.step)
1045
1046    def visit_ExtSlice(self, node):
1047        for idx, item in node.dims:
1048            if idx:
1049                self.write_python(', ')
1050            self.visit(item)
1051
1052    def visit_Yield(self, node):
[c01ed3e]1053        self.unsupported(node)
[15be191]1054
[71779b2]1055        self.write_python('yield ')
1056        self.visit(node.value)
1057
1058    def visit_Lambda(self, node):
[c01ed3e]1059        self.unsupported(node)
[15be191]1060
[71779b2]1061        self.write_python('lambda ')
1062        self.visit(node.args)
1063        self.write_python(': ')
1064        self.visit(node.body)
1065
1066    def visit_Ellipsis(self, node):
[c01ed3e]1067        self.unsupported(node)
[15be191]1068
[71779b2]1069        self.write_python('Ellipsis')
1070
1071    def generator_visit(left, right):
1072        def visit(self, node):
1073            self.write_python(left)
[937afef]1074            self.write_c(left)
[71779b2]1075            self.visit(node.elt)
1076            for comprehension in node.generators:
1077                self.visit(comprehension)
[937afef]1078            self.write_c(right)
[7b1dcf9]1079            #self.write_python(right)
[71779b2]1080        return visit
1081
1082    visit_ListComp = generator_visit('[', ']')
1083    visit_GeneratorExp = generator_visit('(', ')')
1084    visit_SetComp = generator_visit('{', '}')
1085    del generator_visit
1086
1087    def visit_DictComp(self, node):
[c01ed3e]1088        self.unsupported(node)
[15be191]1089
[71779b2]1090        self.write_python('{')
1091        self.visit(node.key)
1092        self.write_python(': ')
1093        self.visit(node.value)
1094        for comprehension in node.generators:
1095            self.visit(comprehension)
1096        self.write_python('}')
1097
1098    def visit_IfExp(self, node):
[4c87de0]1099        self.write_c('((')
[71779b2]1100        self.visit(node.test)
[4c87de0]1101        self.write_c(')?(')
1102        self.visit(node.body)
1103        self.write_c('):(')
[71779b2]1104        self.visit(node.orelse)
[4c87de0]1105        self.write_c('))')
[71779b2]1106
1107    def visit_Starred(self, node):
1108        self.write_c('*')
1109        self.visit(node.value)
1110
1111    def visit_Repr(self, node):
[7b1dcf9]1112        # CRUFT: python 2.6 only
[71779b2]1113        self.write_c('`')
1114        self.visit(node.value)
1115        self.write_python('`')
1116
1117    # Helper Nodes
1118
1119    def visit_alias(self, node):
[c01ed3e]1120        self.unsupported(node)
[15be191]1121
[71779b2]1122        self.write_python(node.name)
1123        if node.asname is not None:
1124            self.write_python(' as ' + node.asname)
1125
1126    def visit_comprehension(self, node):
[5dd7cfb]1127        self.unsupported(node)
1128
[71779b2]1129        self.write_c(' for ')
1130        self.visit(node.target)
[5dd7cfb]1131        self.write_c(' in ')
[7b1dcf9]1132        #self.write_python(' in ')
[71779b2]1133        self.visit(node.iter)
1134        if node.ifs:
1135            for if_ in node.ifs:
1136                self.write_python(' if ')
1137                self.visit(if_)
1138
1139    def visit_arguments(self, node):
1140        self.signature(node)
1141
[c01ed3e]1142    def unsupported(self, node, message=None):
1143        if hasattr(node, "value"):
1144            lineno = node.value.lineno
1145        elif hasattr(node, "iter"):
1146            lineno = node.iter.lineno
1147        else:
1148            #print(dir(node))
1149            lineno = 0
1150
1151        lineno += self.lineno_offset
1152        if self.fname:
1153            location = "%s(%d)" % (self.fname, lineno)
1154        else:
1155            location = "%d" % (self.fname, lineno)
1156        if self.current_function:
1157            location += ", function %s" % self.current_function
1158        if message is None:
1159            message = node.__class__.__name__ + " syntax not supported"
1160        raise SyntaxError("[%s] %s" % (location, message))
1161
[71779b2]1162def print_function(f=None):
1163    """
1164    Print out the code for the function
1165    """
1166    # Include some comments to see if they get printed
1167    import ast
1168    import inspect
1169    if f is not None:
1170        tree = ast.parse(inspect.getsource(f))
[fa74acf]1171        tree_source = to_source(tree)
[71779b2]1172        print(tree_source)
1173
[c01ed3e]1174def define_constant(name, value, block_size=1):
1175    # type: (str, any, int) -> str
1176    """
1177    Convert a python constant into a C constant of the same name.
1178
1179    Returns the C declaration of the constant as a string, possibly containing
1180    line feeds.  The string will not be indented.
1181
1182    Supports int, double and sequences of double.
1183    """
1184    const = "constant "  # OpenCL needs globals to be constant
1185    if isinstance(value, int):
[15be191]1186        parts = [const + "int ", name, " = ", "%d"%value, ";\n"]
[c01ed3e]1187    elif isinstance(value, float):
[15be191]1188        parts = [const + "double ", name, " = ", "%.15g"%value, ";\n"]
[c01ed3e]1189    else:
1190        try:
1191            len(value)
1192        except TypeError:
1193            raise TypeError("constant %s must be int, float or [float, ...]"%name)
1194        # extend constant arrays to a multiple of 4; not sure if this
1195        # is necessary, but some OpenCL targets broke if the number
1196        # of parameters in the parameter table was not a multiple of 4,
1197        # so do it for all constant arrays to be safe.
1198        if len(value)%block_size != 0:
1199            value = list(value) + [0.]*(block_size - len(value)%block_size)
1200        elements = ["%.15g"%v for v in value]
1201        parts = [const + "double ", name, "[]", " = ",
[15be191]1202                 "{\n   ", ", ".join(elements), "\n};\n"]
[c01ed3e]1203
1204    return "".join(parts)
1205
1206
1207# Modified from the following:
1208#
1209#    http://code.activestate.com/recipes/578272-topological-sort/
1210#    Copyright (C) 2012 Sam Denton
1211#    License: MIT
1212def ordered_dag(dag):
1213    # type: (Dict[T, Set[T]]) -> Iterator[T]
1214    """
1215    Given a dag defined by a dictionary of {k1: [k2, ...]} yield keys
1216    in order such that every key occurs after the keys it depends upon.
1217
1218    This is an iterator not a sequence.  To reverse it use::
1219
1220        reversed(tuple(ordered_dag(dag)))
1221
1222    Raise an error if there are any cycles.
1223
1224    Keys are arbitrary hashable values.
1225    """
1226    # Local import to make the function stand-alone, and easier to borrow
1227    from functools import reduce
1228
1229    dag = dag.copy()
1230
1231    # make leaves depend on the empty set
1232    leaves = reduce(set.union, dag.values()) - set(dag.keys())
1233    dag.update({node: set() for node in leaves})
1234    while True:
1235        leaves = set(node for node, links in dag.items() if not links)
1236        if not leaves:
1237            break
1238        for node in leaves:
1239            yield node
1240        dag = {node: (links-leaves)
1241               for node, links in dag.items() if node not in leaves}
1242    if dag:
1243        raise ValueError("Cyclic dependes exists amongst these items:\n%s"
1244                         % ", ".join(str(node) for node in dag.keys()))
1245
[4c87de0]1246import re
1247PRINT_ARGS = re.compile(r'print[(]"(?P<template>[^"]*)" *% *[(](?P<args>[^\n]*)[)] *[)] *\n')
1248SUBST_ARGS = r'printf("\g<template>\\n", \g<args>)\n'
1249PRINT_STR = re.compile(r'print[(]"(?P<template>[^"]*)" *[)] *\n')
1250SUBST_STR = r'printf("\g<template>\n")'
[c01ed3e]1251def translate(functions, constants=None):
[4c87de0]1252    # type: (Sequence[(str, str, int)], Dict[str, any]) -> List[str]
[c01ed3e]1253    """
[4c87de0]1254    Convert a list of functions to a list of C code strings.
1255
[d5014e4]1256    Returns list of corresponding code snippets (with trailing lines in
1257    each block) and a list of warnings generated by the translator.
1258
[4c87de0]1259    A function is given by the tuple (source, filename, line number).
1260
1261    Global constants are given in a dictionary of {name: value}.  The
1262    constants are used for name space resolution and type inferencing.
1263    Constants are not translated by this code. Instead, call
1264    :func:`define_constant` with name and value, and maybe block_size
1265    if arrays need to be padded to the next block boundary.
1266
1267    Function prototypes are not generated. Use :func:`ordered_dag`
1268    to list the functions in reverse order of dependency before calling
1269    translate. [Maybe a future revision will return the function prototypes
1270    so that a suitable "*.h" file can be generated.
[c01ed3e]1271    """
[1ddb794]1272    snippets = []
[4339764]1273    warnings = []
[c01ed3e]1274    for source, fname, lineno in functions:
[15be191]1275        line_directive = '#line %d "%s"\n'%(lineno, fname.replace('\\', '\\\\'))
[1ddb794]1276        snippets.append(line_directive)
[4c87de0]1277        # Replace simple print function calls with printf statements
1278        source = PRINT_ARGS.sub(SUBST_ARGS, source)
1279        source = PRINT_STR.sub(SUBST_STR, source)
[98a4f14]1280        tree = ast.parse(source)
[d5014e4]1281        generator = SourceGenerator(constants=constants, fname=fname, lineno=lineno)
1282        generator.visit(tree)
1283        c_code = "".join(generator.c_proc)
[1ddb794]1284        snippets.append(c_code)
[d5014e4]1285        warnings.extend(generator.warnings)
1286    return snippets, warnings
[71779b2]1287
[4339764]1288
[6f91c91]1289C_HEADER_LINENO = getframeinfo(currentframe()).lineno + 2
[0bd0877]1290C_HEADER = """
[6f91c91]1291#line %d "%s"
[0bd0877]1292#include <stdio.h>
1293#include <stdbool.h>
1294#include <math.h>
1295#define constant const
1296double square(double x) { return x*x; }
1297double cube(double x) { return x*x*x; }
1298double polyval(constant double *coef, double x, int N)
1299{
1300    int i = 0;
1301    double ans = coef[0];
1302
1303    while (i < N) {
1304        ans = ans * x + coef[i++];
1305    }
1306
1307    return ans;
1308}
1309"""
[d5014e4]1310
1311USAGE = """\
1312Usage: python py2c.py <infile> [<outfile>]
1313
1314if outfile is omitted, output file is '<infile>.c'
1315"""
[71779b2]1316
[7b1dcf9]1317def main():
[71779b2]1318    import os
[d7f33e5]1319    #print("Parsing...using Python" + sys.version)
[7b1dcf9]1320    if len(sys.argv) == 1:
[d5014e4]1321        print(USAGE)
[7b1dcf9]1322        return
1323
1324    fname_in = sys.argv[1]
1325    if len(sys.argv) == 2:
1326        fname_base = os.path.splitext(fname_in)[0]
1327        fname_out = str(fname_base) + '.c'
1328    else:
1329        fname_out = sys.argv[2]
1330
1331    with open(fname_in, "r") as python_file:
1332        code = python_file.read()
[c01ed3e]1333    name = "gauss"
1334    code = (code
1335            .replace(name+'.n', 'GAUSS_N')
1336            .replace(name+'.z', 'GAUSS_Z')
[4c87de0]1337            .replace(name+'.w', 'GAUSS_W')
1338            .replace('if __name__ == "__main__"', "def main()")
[0bd0877]1339           )
[7b1dcf9]1340
[d5014e4]1341    translation, warnings = translate([(code, fname_in, 1)])
1342    c_code = "".join(translation)
[4c87de0]1343    c_code = c_code.replace("double main()", "int main(int argc, char *argv[])")
[7b1dcf9]1344
1345    with open(fname_out, "w") as file_out:
[6f91c91]1346        file_out.write(C_HEADER%(C_HEADER_LINENO, __file__))
[4c87de0]1347        file_out.write(c_code)
[d5014e4]1348
[4339764]1349    if warnings:
[d5014e4]1350        print("\n".join(warnings))
[d7f33e5]1351    #print("...Done")
[7b1dcf9]1352
1353if __name__ == "__main__":
1354    main()
Note: See TracBrowser for help on using the repository browser.