source: sasmodels/sasmodels/py2c.py @ d7f33e5

Last change on this file since d7f33e5 was d7f33e5, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

quiet py2c main

  • Property mode set to 100644
File size: 44.2 KB
Line 
1r"""
2py2c
3~~~~
4
5Convert simple numeric python code into C code.
6
7This code is intended to translate direct algorithms for scientific code
8(mostly if statements and for loops operating on double precision values)
9into C code. Unlike projects like numba, cython, pypy and nuitka, the
10:func:`translate` function returns the corresponding C which can then be
11compiled with tinycc or sent to the GPU using CUDA or OpenCL.
12
There is special handling for certain constructs, such as *for i in range*
and small integer powers.
15
**TODO: make a nice list of supported constructs**
17
18Imports are not supported, but they are at least ignored so that properly
19constructed code can be run via python or translated to C without change.
20
21Most other python constructs are **not** supported:
22* classes
23* builtin types (dict, set, list)
24* exceptions
25* with context
26* del
27* yield
28* async
29* list slicing
30* multiple return values
31* "is/is not", "in/not in" conditionals
32
33There is limited support for list and list comprehensions, so long as they
34can be represented by a fixed array whose size is known at compile time, and
35they are small enough to be stored on the stack.
36
Variable definitions in C
-------------------------
Defining variables within the translate function involves a bit of guesswork,
using the following rules:
41*   By default, a variable is a 'double'.
42*   Variable in a for loop is an int.
*   A variable that is referenced with brackets is an array of doubles. The
    variable within the brackets is an integer. For example, in the
    reference 'var1[var2]', var1 is a double array, and var2 is an integer.
*   Assignment to an argument makes that argument an array, and the index
    in that assignment is 0.
    For example, the following python code::

        def func(arg1, arg2):
            arg2 = 17.

    is translated to the following C code::

        double func(double arg1, double arg2[])
        {
            arg2[0] = 17.0;
        }

62*   All functions are defined as double, even if there is no
63    return statement.
64
65Debugging
66---------
67
68*print* is partially supported using a simple regular expression. This
69requires a stylized form. Be sure to use print as a function instead of
the print statement. If you are including substitution variables, use the
71% string substitution style. Include parentheses around the substitution
72tuple, even if there is only one item; do not include the final comma even
73if it is a single item (yes, it won't be a tuple, but it makes the regexp
74much simpler). Keep the item on a single line. Here are three forms that work::
75
76    print("x") => printf("x\n");
77    print("x %g"%(a)) => printf("x %g\n", a);
78    print("x %g %g %g"%(a, b, c)) => printf("x %g %g %g\n", a, b, c);
79
80You can generate *main* using the *if __name__ == "__main__":* construct.
81This does a simple substitution with "def main():" before translation and
82a substitution with "int main(int argc, double *argv[])" after translation.
83The result is that the content of the *if* block becomes the content of *main*.
84Along with the print statement, you can run and test a translation standalone
85using::
86
87    python py2c.py source.py
88    cc source.c
89    ./a.out
90
91Known issues
92------------
93The following constructs may cause problems:
94
95* implicit arrays: possible namespace collision for variable "vec#"
96* swap fails: "x,y = y,x" will set x==y
97* top-level statements: code outside a function body causes errors
98* line number skew: each statement should be tagged with its own #line
99  to avoid skew as comments are skipped and loop bodies are wrapped with
100  braces, etc.
101
102References
103----------
104
105Based on a variant of codegen.py:
106
107    https://github.com/andreif/codegen
108    :copyright: Copyright 2008 by Armin Ronacher.
109    :license: BSD.
110"""
111
112# Update Notes
113# ============
114# 11/22/2017, O.E.   Each 'visit_*' method is to build a C statement string. It
#                     should insert 4 blanks per indentation level.
116#                     The 'body' method will combine all the strings, by adding
117#                     the 'current_statement' to the c_proc string list
118#    11/2017, OE: variables, argument definition implemented.
119#    Note: An argument is considered an array if it is the target of an
120#         assignment. In that case it is translated to <var>[0]
# 11/27/2017, OE: 'pow' basically working
122#   /12/2017, OE: Multiple assignment: a1,a2,...,an=b1,b2,...bn implemented
123#   /12/2017, OE: Power function, including special cases of
124#                 square(x)(pow(x,2)) and cube(x)(pow(x,3)), implemented in
125#                 translate_power, called from visit_BinOp
# 12/07/2017, OE: Translation of integer division, '//' in python, implemented
127#                 in translate_integer_divide, called from visit_BinOp
128# 12/07/2017, OE: C variable definition handled in 'define_c_vars'
129#               : Python integer division, '//', translated to C in
130#                 'translate_integer_divide'
131# 12/15/2017, OE: Precedence maintained by writing opening and closing
#                 parentheses '(',')', in procedure 'visit_BinOp'.
133# 12/18/2017, OE: Added call to 'add_current_line()' at the beginning
134#                 of visit_Return
135# 2018-01-03, PK: Update interface for use in sasmodels
136# 2018-01-03, PK: support "expr if cond else expr" syntax
137# 2018-01-03, PK: x//y => (int)((x)/(y)) and x/y => ((double)(x)/(double)(y))
138# 2018-01-03, PK: True/False => true/false
139# 2018-01-03, PK: f(x) was introducing an extra semicolon
140# 2018-01-03, PK: simplistic print function, for debugging
141# 2018-01-03, PK: while expr: ... => while (expr) { ... }
142
143from __future__ import print_function
144
145import sys
146import ast
147from ast import NodeVisitor
148
# Operator lookup tables: each maps an AST operator node class to the text
# used for it.  The Python spellings '**' and '//' are kept here; per the
# update notes above, visit_BinOp presumably rewrites those specially.
BINOP_SYMBOLS = {
    ast.Add: '+',
    ast.Sub: '-',
    ast.Mult: '*',
    ast.Div: '/',
    ast.Mod: '%',
    ast.Pow: '**',
    ast.LShift: '<<',
    ast.RShift: '>>',
    ast.BitOr: '|',
    ast.BitXor: '^',
    ast.BitAnd: '&',
    ast.FloorDiv: '//',
}

BOOLOP_SYMBOLS = {
    ast.And: '&&',
    ast.Or: '||',
}

# 'is', 'is not', 'in' and 'not in' have no C spelling; the module docstring
# lists them as unsupported.
CMPOP_SYMBOLS = {
    ast.Eq: '==',
    ast.NotEq: '!=',
    ast.Lt: '<',
    ast.LtE: '<=',
    ast.Gt: '>',
    ast.GtE: '>=',
    ast.Is: 'is',
    ast.IsNot: 'is not',
    ast.In: 'in',
    ast.NotIn: 'not in',
}

UNARYOP_SYMBOLS = {
    ast.Invert: '~',
    ast.Not: 'not',
    ast.UAdd: '+',
    ast.USub: '-',
}
184
185
def to_source(tree, constants=None, fname=None, lineno=0):
    """
    Translate the syntax tree *tree* into C source code.

    *constants*, *fname* and *lineno* are forwarded to SourceGenerator.
    The C lines the generator accumulates in its ``c_proc`` list are
    concatenated and returned as one string.
    """
    visitor = SourceGenerator(constants=constants, fname=fname, lineno=lineno)
    visitor.visit(tree)
    return "".join(visitor.c_proc)
194
def isevaluable(s):
    """
    Return True if the expression string *s* evaluates without raising.

    SECURITY: this calls eval(), which executes arbitrary code.  Only pass
    trusted source text.
    """
    try:
        eval(s)
    except Exception:
        return False
    return True
201
202class SourceGenerator(NodeVisitor):
203    """This visitor is able to transform a well formed syntax tree into python
204    sourcecode.  For more details have a look at the docstring of the
205    `node_to_source` function.
206    """
207
208    def __init__(self, indent_with="    ", add_line_information=False,
209                 constants=None, fname=None, lineno=0):
210        self.result = []
211        self.indent_with = indent_with
212        self.add_line_information = add_line_information
213        self.indentation = 0
214        self.new_lines = 0
215
216        # for C
217        self.c_proc = []
218        self.signature_line = 0
219        self.arguments = []
220        self.current_function = ""
221        self.fname = fname
222        self.lineno_offset = lineno
223        self.warnings = []
224        self.statements = []
225        self.current_statement = ""
226        # TODO: use set rather than list for c_vars, ...
227        self.c_vars = []
228        self.c_int_vars = []
229        self.c_pointers = []
230        self.c_dcl_pointers = []
231        self.c_functions = []
232        self.c_vectors = []
233        self.c_constants = constants if constants is not None else {}
234        self.in_expr = False
235        self.in_subref = False
236        self.in_subscript = False
237        self.tuples = []
238        self.required_functions = []
239        self.is_sequence = False
240        self.visited_args = False
241
242    def write_python(self, x):
243        if self.new_lines:
244            if self.result:
245                self.result.append('\n' * self.new_lines)
246            self.result.append(self.indent_with * self.indentation)
247            self.new_lines = 0
248        self.result.append(x)
249
250    def write_c(self, statement):
251        # TODO: build up as a list rather than adding to string
252        self.current_statement += statement
253
254    def add_c_line(self, line):
255        indentation = self.indent_with * self.indentation
256        self.c_proc.append("".join((indentation, line, "\n")))
257
258    def add_current_line(self):
259        if self.current_statement:
260            self.add_c_line(self.current_statement)
261            self.current_statement = ''
262
263    def add_unique_var(self, new_var):
264        if new_var not in self.c_vars:
265            self.c_vars.append(str(new_var))
266
267    def write_sincos(self, node):
268        angle = str(node.args[0].id)
269        self.write_c(node.args[1].id + " = sin(" + angle + ");")
270        self.add_current_line()
271        self.write_c(node.args[2].id + " = cos(" + angle + ");")
272        self.add_current_line()
273        for arg in node.args:
274            self.add_unique_var(arg.id)
275
276    def newline(self, node=None, extra=0):
277        self.new_lines = max(self.new_lines, 1 + extra)
278        if node is not None and self.add_line_information:
279            self.write_c('// line: %s' % node.lineno)
280            self.new_lines = 1
281        if self.current_statement:
282            self.statements.append(self.current_statement)
283            self.current_statement = ''
284
285    def body(self, statements):
286        if self.current_statement:
287            self.add_current_line()
288        self.new_line = True
289        self.indentation += 1
290        for stmt in statements:
291            #if hasattr(stmt, 'targets') and hasattr(stmt.targets[0], 'id'):
292            #    target_name = stmt.targets[0].id # target name needed for debug only
293            self.visit(stmt)
294        self.add_current_line() # just for breaking point. to be deleted.
295        self.indentation -= 1
296
    def body_or_else(self, node):
        """
        Translate the body of a loop *node* (``for`` or ``while``).

        A non-empty ``else:`` clause has no C counterpart and is reported
        through ``self.unsupported``.
        """
        self.body(node.body)
        if node.orelse:
            self.unsupported(node, "for...else/while...else not supported")

            # NOTE(review): the lines below only run if unsupported() returns
            # (it is defined outside this view; presumably it raises -- confirm).
            # They would write a Python-style 'else:' into the C statement.
            self.newline()
            self.write_c('else:')
            self.body(node.orelse)
305
306    def signature(self, node):
307        want_comma = []
308        def write_comma():
309            if want_comma:
310                self.write_c(', ')
311            else:
312                want_comma.append(True)
313
314        # for C
315        for arg in node.args:
316            # CRUFT: 2.7 uses arg.id, 3.x uses arg.arg
317            try:
318                arg_name = arg.arg
319            except AttributeError:
320                arg_name = arg.id
321            self.arguments.append(arg_name)
322
323        padding = [None] *(len(node.args) - len(node.defaults))
324        for arg, default in zip(node.args, padding + node.defaults):
325            if default is not None:
326                # CRUFT: 2.7 uses arg.id, 3.x uses arg.arg
327                try:
328                    arg_name = arg.arg
329                except AttributeError:
330                    arg_name = arg.id
331                w_str = ("Default Parameters are unknown to C: '%s = %s"
332                         % (arg_name, str(default.n)))
333                self.warnings.append(w_str)
334
335    def decorators(self, node):
336        for decorator in node.decorator_list:
337            self.newline(decorator)
338            self.write_python('@')
339            self.visit(decorator)
340
341    # Statements
342
343    def visit_Assert(self, node):
344        self.newline(node)
345        self.write_c('assert ')
346        self.visit(node.test)
347        if node.msg is not None:
348            self.write_python(', ')
349            self.visit(node.msg)
350
    def define_c_vars(self, target):
        """
        Register the assignment *target* as a C variable.

        A plain name that is a function argument is treated as an output
        array: ``target.id`` is rewritten in place to ``"<name>[0]"`` and the
        bare name is added to ``self.c_pointers``, so insert_signature
        declares it as ``double name[]``.  Any other new name is added to
        ``self.c_vars``.  Targets without an ``id`` (tuples, subscripts) are
        ignored.
        """
        if hasattr(target, 'id'):
        # A variable is considered an array if it appears in the argument list
        # and is assigned to. For example, the variable p in the following
        # snippet is a pointer, while q is not:
        # def somefunc(p, q):
        #  p = q + 1
        #  return
        #
            if target.id not in self.c_vars:
                if target.id in self.arguments:
                    idx = self.arguments.index(target.id)
                    new_target = self.arguments[idx] + "[0]"
                    # NOTE(review): c_pointers holds bare names, so this test
                    # is always true and a name may be appended repeatedly.
                    if new_target not in self.c_pointers:
                        target.id = new_target
                        self.c_pointers.append(self.arguments[idx])
                else:
                    self.c_vars.append(target.id)
369
370    def add_semi_colon(self):
371        #semi_pos = self.current_statement.find(';')
372        #if semi_pos >= 0:
373        #    self.current_statement = self.current_statement.replace(';', '')
374        self.write_c(';')
375
376    def visit_Assign(self, node):
377        self.add_current_line()
378        self.in_expr = True
379        for idx, target in enumerate(node.targets): # multi assign, as in 'a = b = c = 7'
380            if idx:
381                self.write_c(' = ')
382            self.define_c_vars(target)
383            self.visit(target)
384        # Capture assigned tuple names, if any
385        targets = self.tuples[:]
386        del self.tuples[:]
387        self.write_c(' = ')
388        self.is_sequence = False
389        self.visited_args = False
390        self.visit(node.value)
391        self.add_semi_colon()
392        self.add_current_line()
393        # Assign tuples to tuples, if any
394        # TODO: doesn't handle swap:  a,b = b,a
395        for target, item in zip(targets, self.tuples):
396            self.visit(target)
397            self.write_c(' = ')
398            self.visit(item)
399            self.add_semi_colon()
400            self.add_current_line()
401        if self.is_sequence and not self.visited_args:
402            for target in node.targets:
403                if hasattr(target, 'id'):
404                    if target.id in self.c_vars and target.id not in self.c_dcl_pointers:
405                        if target.id not in self.c_dcl_pointers:
406                            self.c_dcl_pointers.append(target.id)
407                            if target.id in self.c_vars:
408                                self.c_vars.remove(target.id)
409        self.current_statement = ''
410        self.in_expr = False
411
412    def visit_AugAssign(self, node):
413        if node.target.id not in self.c_vars:
414            if node.target.id not in self.arguments:
415                self.c_vars.append(node.target.id)
416        self.in_expr = True
417        self.visit(node.target)
418        self.write_c(' ' + BINOP_SYMBOLS[type(node.op)] + '= ')
419        self.visit(node.value)
420        self.add_semi_colon()
421        self.in_expr = False
422        self.add_current_line()
423
424    def visit_ImportFrom(self, node):
425        return  # import ignored
426        self.newline(node)
427        self.write_python('from %s%s import ' %('.' * node.level, node.module))
428        for idx, item in enumerate(node.names):
429            if idx:
430                self.write_python(', ')
431            self.write_python(item)
432
433    def visit_Import(self, node):
434        return  # import ignored
435        self.newline(node)
436        for item in node.names:
437            self.write_python('import ')
438            self.visit(item)
439
440    def visit_Expr(self, node):
441        #self.in_expr = True
442        self.newline(node)
443        self.generic_visit(node)
444        #self.in_expr = False
445
446    def write_c_pointers(self, start_var):
447        if self.c_dcl_pointers:
448            var_list = []
449            for c_ptr in self.c_dcl_pointers:
450                if c_ptr not in self.arguments:
451                    var_list.append("*" + c_ptr)
452                if c_ptr in self.c_vars:
453                    self.c_vars.remove(c_ptr)
454            if var_list:
455                c_dcl = "    double " + ", ".join(var_list) + ";\n"
456                self.c_proc.insert(start_var, c_dcl)
457                start_var += 1
458        return start_var
459
460    def insert_c_vars(self, start_var):
461        have_decls = False
462        start_var = self.write_c_pointers(start_var)
463        if self.c_int_vars:
464            for var in self.c_int_vars:
465                if var in self.c_vars:
466                    self.c_vars.remove(var)
467            decls = ", ".join(self.c_int_vars)
468            self.c_proc.insert(start_var, "    int " + decls + ";\n")
469            have_decls = True
470            start_var += 1
471
472        if self.c_vars:
473            decls = ", ".join(self.c_vars)
474            self.c_proc.insert(start_var, "    double " + decls + ";\n")
475            have_decls = True
476            start_var += 1
477
478        if self.c_vectors:
479            for vec_number, vec_value  in enumerate(self.c_vectors):
480                name = "vec" + str(vec_number + 1)
481                decl = "    double " + name + "[] = {" + vec_value + "};"
482                self.c_proc.insert(start_var, decl + "\n")
483                start_var += 1
484
485        del self.c_vars[:]
486        del self.c_int_vars[:]
487        del self.c_vectors[:]
488        del self.c_pointers[:]
489        del self.c_dcl_pointers[:]
490        if have_decls:
491            self.c_proc.insert(start_var, "\n")
492
493    def insert_signature(self):
494        arg_decls = []
495        for arg in self.arguments:
496            decl = "double " + arg
497            if arg in self.c_pointers:
498                decl += "[]"
499            arg_decls.append(decl)
500        args_str = ", ".join(arg_decls)
501        method_sig = 'double ' + self.current_function + '(' + args_str + ")"
502        if self.signature_line >= 0:
503            self.c_proc.insert(self.signature_line, method_sig)
504
    def visit_FunctionDef(self, node):
        """
        Translate a function definition into a C function returning double.

        The C text is assembled out of order.  The body is translated first,
        which discovers the pointer arguments and the local variables.  Then
        the prototype is inserted at ``signature_line`` and the local
        declarations just after the opening brace.  Nested function
        definitions are reported through ``self.unsupported``.
        """
        if self.current_function:
            self.unsupported(node, "function within a function")
        self.current_function = node.name

        self.newline(extra=1)
        self.decorators(node)
        self.newline(node)
        self.arguments = []
        # Presumably dispatched to a visit_arguments handler outside this view
        # that calls signature() to fill self.arguments -- confirm.
        self.visit(node.args)
        # for C
        self.signature_line = len(self.c_proc)
        self.add_c_line("\n{")
        # The brace line sits at signature_line for now.  insert_signature
        # later inserts the prototype in front of it, so the first slot after
        # the brace will be signature_line + 2 == len(c_proc) + 1.
        start_vars = len(self.c_proc) + 1
        self.body(node.body)
        self.add_c_line("}\n")
        self.insert_signature()
        self.insert_c_vars(start_vars)
        del self.c_pointers[:]
        self.current_function = ""
525
526    def visit_ClassDef(self, node):
527        have_args = []
528        def paren_or_comma():
529            if have_args:
530                self.write_python(', ')
531            else:
532                have_args.append(True)
533                self.write_python('(')
534
535        self.newline(extra=2)
536        self.decorators(node)
537        self.newline(node)
538        self.write_python('class %s' % node.name)
539        for base in node.bases:
540            paren_or_comma()
541            self.visit(base)
542        # CRUFT: python 2.6 does not have "keywords" attribute
543        if hasattr(node, 'keywords'):
544            for keyword in node.keywords:
545                paren_or_comma()
546                self.write_python(keyword.arg + '=')
547                self.visit(keyword.value)
548            if node.starargs is not None:
549                paren_or_comma()
550                self.write_python('*')
551                self.visit(node.starargs)
552            if node.kwargs is not None:
553                paren_or_comma()
554                self.write_python('**')
555                self.visit(node.kwargs)
556        self.write_python(have_args and '):' or ':')
557        self.body(node.body)
558
559    def visit_If(self, node):
560
561        self.write_c('if ')
562        self.in_expr = True
563        self.visit(node.test)
564        self.in_expr = False
565        self.write_c(' {')
566        self.body(node.body)
567        self.add_c_line('}')
568        while True:
569            else_ = node.orelse
570            if len(else_) == 0:
571                break
572            #elif hasattr(else_, 'orelse'):
573            elif len(else_) == 1 and isinstance(else_[0], ast.If):
574                node = else_[0]
575                #self.newline()
576                self.write_c('else if ')
577                self.in_expr = True
578                self.visit(node.test)
579                self.in_expr = False
580                self.write_c(' {')
581                self.body(node.body)
582                self.add_current_line()
583                self.add_c_line('}')
584                #break
585            else:
586                self.newline()
587                self.write_c('else {')
588                self.body(node.body)
589                self.add_c_line('}')
590                break
591
592    def get_for_range(self, node):
593        stop = ""
594        start = '0'
595        step = '1'
596        for_args = []
597        temp_statement = self.current_statement
598        self.current_statement = ''
599        for arg in node.iter.args:
600            self.visit(arg)
601            for_args.append(self.current_statement)
602            self.current_statement = ''
603        self.current_statement = temp_statement
604        if len(for_args) == 1:
605            stop = for_args[0]
606        elif len(for_args) == 2:
607            start = for_args[0]
608            stop = for_args[1]
609        elif len(for_args) == 3:
610            start = for_args[0]
611            stop = for_args[1]
612            start = for_args[2]
613        else:
614            raise("Ilegal for loop parameters")
615        return start, stop, step
616
617    def visit_For(self, node):
618        # node: for iterator is stored in node.target.
619        # Iterator name is in node.target.id.
620        self.add_current_line()
621        fForDone = False
622        self.current_statement = ''
623        if hasattr(node.iter, 'func'):
624            if hasattr(node.iter.func, 'id'):
625                if node.iter.func.id == 'range':
626                    self.visit(node.target)
627                    iterator = self.current_statement
628                    self.current_statement = ''
629                    if iterator not in self.c_int_vars:
630                        self.c_int_vars.append(iterator)
631                    start, stop, step = self.get_for_range(node)
632                    self.write_c("for (" + iterator + "=" + str(start) +
633                                 " ; " + iterator + " < " + str(stop) +
634                                 " ; " + iterator + " += " + str(step) + ") {")
635                    self.body_or_else(node)
636                    self.write_c("}")
637                    fForDone = True
638        if not fForDone:
639            # Generate the statement that is causing the error
640            self.current_statement = ''
641            self.write_c('for ')
642            self.visit(node.target)
643            self.write_c(' in ')
644            self.visit(node.iter)
645            self.write_c(':')
646            # report the error
647            self.unsupported(node, "unsupported " + self.current_statement)
648
649    def visit_While(self, node):
650        self.newline(node)
651        self.write_c('while ')
652        self.visit(node.test)
653        self.write_c(' {')
654        self.body_or_else(node)
655        self.write_c('}')
656        self.add_current_line()
657
658    def visit_With(self, node):
659        self.unsupported(node)
660
661        self.newline(node)
662        self.write_python('with ')
663        self.visit(node.context_expr)
664        if node.optional_vars is not None:
665            self.write_python(' as ')
666            self.visit(node.optional_vars)
667        self.write_python(':')
668        self.body(node.body)
669
670    def visit_Pass(self, node):
671        self.newline(node)
672        #self.write_python('pass')
673
    def visit_Print(self, node):
        """
        Reject the ``print`` statement.

        The module docstring asks for print() to be used as a function
        instead.  The code below only runs if ``self.unsupported`` returns.
        """
        self.unsupported(node)

        # CRUFT: python 2.6 only
        self.newline(node)
        self.write_c('print ')
        want_comma = False
        # 'print >> dest, ...' form
        if node.dest is not None:
            self.write_c(' >> ')
            self.visit(node.dest)
            want_comma = True
        for value in node.values:
            if want_comma:
                self.write_c(', ')
            self.visit(value)
            want_comma = True
        # A trailing comma suppresses the newline in Python 2 print.
        if not node.nl:
            self.write_c(',')
692
693    def visit_Delete(self, node):
694        self.unsupported(node)
695
696        self.newline(node)
697        self.write_python('del ')
698        for idx, target in enumerate(node):
699            if idx:
700                self.write_python(', ')
701            self.visit(target)
702
    def visit_TryExcept(self, node):
        """
        Reject ``try``/``except``; exceptions are not supported.

        The echo below only runs if ``self.unsupported`` returns.
        NOTE(review): Python 3 parses ``try`` into a single ``Try`` node, so
        presumably this handler is only reached on Python 2 -- confirm.
        """
        self.unsupported(node)

        self.newline(node)
        self.write_python('try:')
        self.body(node.body)
        for handler in node.handlers:
            self.visit(handler)
711
    def visit_TryFinally(self, node):
        """
        Reject ``try``/``finally``; exceptions are not supported.

        The echo below only runs if ``self.unsupported`` returns.
        NOTE(review): Python 3 parses ``try`` into a single ``Try`` node, so
        presumably this handler is only reached on Python 2 -- confirm.
        """
        self.unsupported(node)

        self.newline(node)
        self.write_python('try:')
        self.body(node.body)
        self.newline(node)
        self.write_python('finally:')
        self.body(node.finalbody)
721
722    def visit_Global(self, node):
723        self.unsupported(node)
724
725        self.newline(node)
726        self.write_python('global ' + ', '.join(node.names))
727
728    def visit_Nonlocal(self, node):
729        self.newline(node)
730        self.write_python('nonlocal ' + ', '.join(node.names))
731
732    def visit_Return(self, node):
733        self.add_current_line()
734        self.in_expr = True
735        if node.value is None:
736            self.write_c('return')
737        else:
738            self.write_c('return(')
739            self.visit(node.value)
740        self.write_c(')')
741        self.add_semi_colon()
742        self.in_expr = False
743        self.add_c_line(self.current_statement)
744        self.current_statement = ''
745
746    def visit_Break(self, node):
747        self.newline(node)
748        self.write_c('break')
749
750    def visit_Continue(self, node):
751        self.newline(node)
752        self.write_c('continue')
753
    def visit_Raise(self, node):
        """
        Reject ``raise``; exceptions are not supported.

        The echo below only runs if ``self.unsupported`` returns.  It handles
        both the ``exc``/``cause`` fields and the older
        ``type``/``inst``/``tback`` fields.
        """
        self.unsupported(node)

        # CRUFT: Python 2.6 / 3.0 compatibility
        self.newline(node)
        self.write_python('raise')
        if hasattr(node, 'exc') and node.exc is not None:
            self.write_python(' ')
            self.visit(node.exc)
            if node.cause is not None:
                self.write_python(' from ')
                self.visit(node.cause)
        elif hasattr(node, 'type') and node.type is not None:
            # NOTE(review): unlike the branch above, no ' ' is written
            # after 'raise' here.
            self.visit(node.type)
            if node.inst is not None:
                self.write_python(', ')
                self.visit(node.inst)
            if node.tback is not None:
                self.write_python(', ')
                self.visit(node.tback)
774
775    # Expressions
776
777    def visit_Attribute(self, node):
778        self.unsupported(node, "attribute reference a.b not supported")
779
780        self.visit(node.value)
781        self.write_python('.' + node.attr)
782
    def visit_Call(self, node):
        """
        Translate a function call ``f(args)`` into C.

        Special cases:
        * ``abs`` is emitted as ``fabs``.
        * ``int(x)`` becomes the cast ``(int) (x)``.
        * ``SINCOS(angle, s, c)`` is expanded by write_sincos.

        Every called name is recorded in ``self.c_functions``, which
        visit_Name checks so the name is not declared as a variable.  When
        ``self.in_expr`` is false (e.g. a call used as a statement), a
        terminating ';' is added.
        """
        want_comma = []
        def write_comma():
            # ', ' before every argument except the first.
            if want_comma:
                self.write_c(', ')
            else:
                want_comma.append(True)
        if hasattr(node.func, 'id'):
            if node.func.id not in self.c_functions:
                self.c_functions.append(node.func.id)
            if node.func.id == 'abs':
                self.write_c("fabs ")
            elif node.func.id == 'int':
                self.write_c('(int) ')
            elif node.func.id == "SINCOS":
                self.write_sincos(node)
                return
            else:
                self.visit(node.func)
        else:
            self.visit(node.func)
        self.write_c('(')
        for arg in node.args:
            write_comma()
            # Lets visit_Assign distinguish a sequence passed to a call from
            # one assigned to a variable.
            self.visited_args = True
            self.visit(arg)
        # NOTE(review): keyword arguments are emitted as 'name=value', which
        # C parses as an assignment expression, not a named argument.
        for keyword in node.keywords:
            write_comma()
            self.write_c(keyword.arg + '=')
            self.visit(keyword.value)
        # CRUFT: starargs/kwargs are only present on older Python ASTs.
        if hasattr(node, 'starargs'):
            if node.starargs is not None:
                write_comma()
                self.write_c('*')
                self.visit(node.starargs)
        if hasattr(node, 'kwargs'):
            if node.kwargs is not None:
                write_comma()
                self.write_c('**')
                self.visit(node.kwargs)
        self.write_c(')')
        if not self.in_expr:
            self.add_semi_colon()
826
827    TRANSLATE_CONSTANTS = {
828        # python 2 uses normal name references through vist_Name
829        'True': 'true',
830        'False': 'false',
831        'None': 'NULL',  # "None" will probably fail for other reasons
832        # python 3 uses NameConstant
833        True: 'true',
834        False: 'false',
835        None: 'NULL',  # "None" will probably fail for other reasons
836        }
837
    def visit_Name(self, node):
        """
        Emit a name reference and register names not seen before.

        ``True``/``False``/``None`` (plain names on Python 2) are replaced
        via TRANSLATE_CONSTANTS.  A pointer argument (listed in
        ``self.c_pointers``) is written as ``name[0]`` unless
        ``self.in_subref`` is set.

        A name is registered only if it is not a called function, an
        already-known variable, an argument, a constant or a digit string.
        Such a name goes to ``self.c_int_vars`` inside a subscript and to
        ``self.c_vars`` otherwise.
        """
        translation = self.TRANSLATE_CONSTANTS.get(node.id, None)
        if translation:
            self.write_c(translation)
            return
        self.write_c(node.id)
        if node.id in self.c_pointers and not self.in_subref:
            self.write_c("[0]")
        name = ""
        # define_c_vars may have rewritten an argument target's id to
        # "name[0]"; look up only the part before the bracket.
        sub = node.id.find("[")
        if sub > 0:
            name = node.id[0:sub].strip()
        else:
            name = node.id
        # add variable to c_vars if it isn't there yet, not an argument and not a number
        if (name not in self.c_functions and name not in self.c_vars and
                name not in self.c_int_vars and name not in self.arguments and
                name not in self.c_constants and not name.isdigit()):
            if self.in_subscript:
                self.c_int_vars.append(node.id)
            else:
                self.c_vars.append(node.id)
860
861    def visit_NameConstant(self, node):
862        translation = self.TRANSLATE_CONSTANTS.get(node.value, None)
863        if translation is not None:
864            self.write_c(translation)
865        else:
866            self.unsupported(node, "don't know how to translate %r"%node.value)
867
868    def visit_Str(self, node):
869        s = node.s
870        s = s.replace('\\','\\\\').replace('"','\\"').replace('\n','\\n')
871        self.write_c('"')
872        self.write_c(s)
873        self.write_c('"')
874
875    def visit_Bytes(self, node):
876        s = node.s
877        s = s.replace('\\','\\\\').replace('"','\\"').replace('\n','\\n')
878        self.write_c('"')
879        self.write_c(s)
880        self.write_c('"')
881
882    def visit_Num(self, node):
883        self.write_c(repr(node.n))
884
885    def visit_Tuple(self, node):
886        for idx, item in enumerate(node.elts):
887            if idx:
888                self.tuples.append(item)
889            else:
890                self.visit(item)
891
892    def sequence_visit(left, right):
893        def visit(self, node):
894            self.is_sequence = True
895            s = ""
896            for idx, item in enumerate(node.elts):
897                if idx > 0 and s:
898                    s += ', '
899                if hasattr(item, 'id'):
900                    s += item.id
901                elif hasattr(item, 'n'):
902                    s += str(item.n)
903            if s:
904                self.c_vectors.append(s)
905                vec_name = "vec"  + str(len(self.c_vectors))
906                self.write_c(vec_name)
907        return visit
908
909    visit_List = sequence_visit('[', ']')
910    visit_Set = sequence_visit('{', '}')
911    del sequence_visit
912
913    def visit_Dict(self, node):
914        self.unsupported(node)
915
916        self.write_python('{')
917        for idx, (key, value) in enumerate(zip(node.keys, node.values)):
918            if idx:
919                self.write_python(', ')
920            self.visit(key)
921            self.write_python(': ')
922            self.visit(value)
923        self.write_python('}')
924
925    def get_special_power(self, string):
926        function_name = ''
927        is_negative_exp = False
928        if isevaluable(str(self.current_statement)):
929            exponent = eval(string)
930            is_negative_exp = exponent < 0
931            abs_exponent = abs(exponent)
932            if abs_exponent == 2:
933                function_name = "square"
934            elif abs_exponent == 3:
935                function_name = "cube"
936            elif abs_exponent == 0.5:
937                function_name = "sqrt"
938            elif abs_exponent == 1.0/3.0:
939                function_name = "cbrt"
940        if function_name == '':
941            function_name = "pow"
942        return function_name, is_negative_exp
943
944    def translate_power(self, node):
945        # get exponent by visiting the right hand argument.
946        function_name = "pow"
947        temp_statement = self.current_statement
948        # 'visit' functions write the results to the 'current_statement' class memnber
949        # Here, a temporary variable, 'temp_statement', is used, that enables the
950        # use of the 'visit' function
951        self.current_statement = ''
952        self.visit(node.right)
953        exponent = self.current_statement.replace(' ', '')
954        function_name, is_negative_exp = self.get_special_power(self.current_statement)
955        self.current_statement = temp_statement
956        if is_negative_exp:
957            self.write_c("1.0 /(")
958        self.write_c(function_name + "(")
959        self.visit(node.left)
960        if function_name == "pow":
961            self.write_c(", ")
962            self.visit(node.right)
963        self.write_c(")")
964        if is_negative_exp:
965            self.write_c(")")
966        #self.write_c(" ")
967
968    def translate_integer_divide(self, node):
969        self.write_c("(int)((")
970        self.visit(node.left)
971        self.write_c(")/(")
972        self.visit(node.right)
973        self.write_c("))")
974
975    def translate_float_divide(self, node):
976        self.write_c("((double)(")
977        self.visit(node.left)
978        self.write_c(")/(double)(")
979        self.visit(node.right)
980        self.write_c("))")
981
982    def visit_BinOp(self, node):
983        self.write_c("(")
984        if '%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.Pow]:
985            self.translate_power(node)
986        elif '%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.FloorDiv]:
987            self.translate_integer_divide(node)
988        elif '%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.Div]:
989            self.translate_float_divide(node)
990        else:
991            self.visit(node.left)
992            self.write_c(' %s ' % BINOP_SYMBOLS[type(node.op)])
993            self.visit(node.right)
994        self.write_c(")")
995
996    # for C
997    def visit_BoolOp(self, node):
998        self.write_c('(')
999        for idx, value in enumerate(node.values):
1000            if idx:
1001                self.write_c(' %s ' % BOOLOP_SYMBOLS[type(node.op)])
1002            self.visit(value)
1003        self.write_c(')')
1004
1005    def visit_Compare(self, node):
1006        self.write_c('(')
1007        self.visit(node.left)
1008        for op, right in zip(node.ops, node.comparators):
1009            self.write_c(' %s ' % CMPOP_SYMBOLS[type(op)])
1010            self.visit(right)
1011        self.write_c(')')
1012
1013    def visit_UnaryOp(self, node):
1014        self.write_c('(')
1015        op = UNARYOP_SYMBOLS[type(node.op)]
1016        self.write_c(op)
1017        if op == 'not':
1018            self.write_c(' ')
1019        self.visit(node.operand)
1020        self.write_c(')')
1021
1022    def visit_Subscript(self, node):
1023        if node.value.id not in self.c_constants:
1024            if node.value.id not in self.c_pointers:
1025                self.c_pointers.append(node.value.id)
1026        self.in_subref = True
1027        self.visit(node.value)
1028        self.in_subref = False
1029        self.write_c('[')
1030        self.in_subscript = True
1031        self.visit(node.slice)
1032        self.in_subscript = False
1033        self.write_c(']')
1034
1035    def visit_Slice(self, node):
1036        if node.lower is not None:
1037            self.visit(node.lower)
1038        self.write_python(':')
1039        if node.upper is not None:
1040            self.visit(node.upper)
1041        if node.step is not None:
1042            self.write_python(':')
1043            if not(isinstance(node.step, Name) and node.step.id == 'None'):
1044                self.visit(node.step)
1045
1046    def visit_ExtSlice(self, node):
1047        for idx, item in node.dims:
1048            if idx:
1049                self.write_python(', ')
1050            self.visit(item)
1051
1052    def visit_Yield(self, node):
1053        self.unsupported(node)
1054
1055        self.write_python('yield ')
1056        self.visit(node.value)
1057
1058    def visit_Lambda(self, node):
1059        self.unsupported(node)
1060
1061        self.write_python('lambda ')
1062        self.visit(node.args)
1063        self.write_python(': ')
1064        self.visit(node.body)
1065
1066    def visit_Ellipsis(self, node):
1067        self.unsupported(node)
1068
1069        self.write_python('Ellipsis')
1070
1071    def generator_visit(left, right):
1072        def visit(self, node):
1073            self.write_python(left)
1074            self.write_c(left)
1075            self.visit(node.elt)
1076            for comprehension in node.generators:
1077                self.visit(comprehension)
1078            self.write_c(right)
1079            #self.write_python(right)
1080        return visit
1081
1082    visit_ListComp = generator_visit('[', ']')
1083    visit_GeneratorExp = generator_visit('(', ')')
1084    visit_SetComp = generator_visit('{', '}')
1085    del generator_visit
1086
1087    def visit_DictComp(self, node):
1088        self.unsupported(node)
1089
1090        self.write_python('{')
1091        self.visit(node.key)
1092        self.write_python(': ')
1093        self.visit(node.value)
1094        for comprehension in node.generators:
1095            self.visit(comprehension)
1096        self.write_python('}')
1097
1098    def visit_IfExp(self, node):
1099        self.write_c('((')
1100        self.visit(node.test)
1101        self.write_c(')?(')
1102        self.visit(node.body)
1103        self.write_c('):(')
1104        self.visit(node.orelse)
1105        self.write_c('))')
1106
1107    def visit_Starred(self, node):
1108        self.write_c('*')
1109        self.visit(node.value)
1110
1111    def visit_Repr(self, node):
1112        # CRUFT: python 2.6 only
1113        self.write_c('`')
1114        self.visit(node.value)
1115        self.write_python('`')
1116
1117    # Helper Nodes
1118
1119    def visit_alias(self, node):
1120        self.unsupported(node)
1121
1122        self.write_python(node.name)
1123        if node.asname is not None:
1124            self.write_python(' as ' + node.asname)
1125
1126    def visit_comprehension(self, node):
1127        self.write_c(' for ')
1128        self.visit(node.target)
1129        self.write_C(' in ')
1130        #self.write_python(' in ')
1131        self.visit(node.iter)
1132        if node.ifs:
1133            for if_ in node.ifs:
1134                self.write_python(' if ')
1135                self.visit(if_)
1136
1137    def visit_arguments(self, node):
1138        self.signature(node)
1139
1140    def unsupported(self, node, message=None):
1141        if hasattr(node, "value"):
1142            lineno = node.value.lineno
1143        elif hasattr(node, "iter"):
1144            lineno = node.iter.lineno
1145        else:
1146            #print(dir(node))
1147            lineno = 0
1148
1149        lineno += self.lineno_offset
1150        if self.fname:
1151            location = "%s(%d)" % (self.fname, lineno)
1152        else:
1153            location = "%d" % (self.fname, lineno)
1154        if self.current_function:
1155            location += ", function %s" % self.current_function
1156        if message is None:
1157            message = node.__class__.__name__ + " syntax not supported"
1158        raise SyntaxError("[%s] %s" % (location, message))
1159
def print_function(f=None):
    """
    Print the code that to_source generates for function *f*.

    The source of *f* is fetched with inspect.getsource and parsed with ast
    before being passed to to_source.  Nothing is printed when f is None.
    """
    # Local imports keep this debugging helper self-contained.
    import ast
    import inspect
    if f is None:
        return
    source = inspect.getsource(f)
    print(to_source(ast.parse(source)))
1171
def define_constant(name, value, block_size=1):
    # type: (str, any, int) -> str
    """
    Convert a python constant into a C constant of the same name.

    Returns the C declaration of the constant as a string, possibly containing
    line feeds.  The string will not be indented.

    Supports integers, floats and sequences of floats.  Any value registered
    as a :class:`numbers.Integral` (python int/long, bool, numpy integer
    scalars) is declared as an int.  Any other :class:`numbers.Real` (python
    float, numpy floating scalars) is declared as a double.  Anything else
    with a length is declared as a double array, padded with zeros to a
    multiple of *block_size* when block_size > 1.

    Raises TypeError if *value* is neither a real number nor a sized sequence.
    """
    # Local import to keep the function stand-alone.
    import numbers

    const = "constant "  # OpenCL needs globals to be constant
    if isinstance(value, numbers.Integral):
        parts = [const + "int ", name, " = ", "%d"%value, ";\n"]
    elif isinstance(value, numbers.Real):
        parts = [const + "double ", name, " = ", "%.15g"%value, ";\n"]
    else:
        try:
            len(value)
        except TypeError:
            raise TypeError("constant %s must be int, float or [float, ...]"%name)
        # Extend constant arrays to a multiple of block_size.  Some OpenCL
        # targets broke if the number of parameters in the parameter table
        # was not a multiple of 4, so callers may pad all constant arrays to
        # be safe.  A block_size of 1 or less means no padding (and avoids a
        # modulo by zero).
        if block_size > 1 and len(value)%block_size != 0:
            value = list(value) + [0.]*(block_size - len(value)%block_size)
        elements = ["%.15g"%v for v in value]
        parts = [const + "double ", name, "[]", " = ",
                 "{\n   ", ", ".join(elements), "\n};\n"]

    return "".join(parts)
1203
1204
1205# Modified from the following:
1206#
1207#    http://code.activestate.com/recipes/578272-topological-sort/
1208#    Copyright (C) 2012 Sam Denton
1209#    License: MIT
def ordered_dag(dag):
    # type: (Dict[T, Set[T]]) -> Iterator[T]
    """
    Given a dag defined by a dictionary of {k1: [k2, ...]} yield keys
    in order such that every key occurs after the keys it depends upon.

    The dependencies of each key may be any iterable (set, list, tuple).
    They are converted to sets internally, and *dag* itself is not
    modified.  An empty dag yields nothing.

    This is an iterator not a sequence.  To reverse it use::

        reversed(tuple(ordered_dag(dag)))

    Raises ValueError if there are any cycles.

    Keys are arbitrary hashable values.
    """
    # Private copy with set-valued links: set union and difference below
    # need real sets, and the caller's dict must not be touched.
    dag = {node: set(links) for node, links in dag.items()}

    # Make leaves (items depended upon but without an entry of their own)
    # depend on the empty set.  set().union(*...) also handles an empty dag.
    leaves = set().union(*dag.values()) - set(dag)
    dag.update({node: set() for node in leaves})
    while True:
        leaves = set(node for node, links in dag.items() if not links)
        if not leaves:
            break
        for node in leaves:
            yield node
        dag = {node: (links-leaves)
               for node, links in dag.items() if node not in leaves}
    if dag:
        raise ValueError("Cyclic dependes exists amongst these items:\n%s"
                         % ", ".join(str(node) for node in dag.keys()))
1243
1244import re
# Rewrite simple python print calls as C printf calls before parsing.  Each
# pattern consumes the trailing newline of the print line, so each
# substitution template must put one back.  In these raw re.sub templates,
# "\\n" yields the two characters backslash-n inside the printf string
# literal.  A plain "\n" yields a real newline, used only for the newline
# that ends the line.
PRINT_ARGS = re.compile(r'print[(]"(?P<template>[^"]*)" *% *[(](?P<args>[^\n]*)[)] *[)] *\n')
SUBST_ARGS = r'printf("\g<template>\\n", \g<args>)\n'
PRINT_STR = re.compile(r'print[(]"(?P<template>[^"]*)" *[)] *\n')
SUBST_STR = r'printf("\g<template>\\n")\n'
def translate(functions, constants=None):
    # type: (Sequence[(str, str, int)], Dict[str, any]) -> List[str]
    """
    Convert a list of functions to a list of C code strings.

    A function is given by the tuple (source, filename, line number).  Each
    function contributes two strings, in order:
    - a ``#line`` directive naming the original file and line;
    - the translated C code.
    Simple ``print("...")`` and ``print("..." % (...))`` calls are rewritten
    as ``printf`` calls before the source is parsed.

    Global constants are given in a dictionary of {name: value}.  The
    constants are used for name space resolution and type inferencing.
    Constants are not translated by this code.  Instead, call
    :func:`define_constant` with name and value, and maybe block_size
    if arrays need to be padded to the next block boundary.

    Function prototypes are not generated.  Use :func:`ordered_dag`
    to list the functions in reverse order of dependency before calling
    translate.  (Maybe a future revision will return the function prototypes
    so that a suitable "*.h" file can be generated.)
    """
    snippets = []
    for source, fname, lineno in functions:
        quoted_fname = fname.replace('\\', '\\\\')
        snippets.append('#line %d "%s"\n' % (lineno, quoted_fname))
        # Replace simple print function calls with printf statements
        for pattern, replacement in ((PRINT_ARGS, SUBST_ARGS),
                                     (PRINT_STR, SUBST_STR)):
            source = pattern.sub(replacement, source)
        tree = ast.parse(source)
        snippets.append(to_source(tree, constants=constants, fname=fname,
                                  lineno=lineno))
    return snippets
1280
def main():
    """
    Command line entry point: ``python py2c.py <infile> [<outfile>]``.

    Translates the python file *infile* to C and writes the result to
    *outfile*.  The output defaults to *infile* with its extension replaced
    by ".c", and starts with a small C prelude defining square, cube and
    polyval.

    Before translation:
    - ``gauss.n``, ``gauss.z`` and ``gauss.w`` are replaced by GAUSS_N,
      GAUSS_Z and GAUSS_W;
    - an ``if __name__ == "__main__"`` line is turned into ``def main()``,
      which is then given the C entry point signature.
    """
    import os
    if len(sys.argv) == 1:
        print("""\
Usage: python py2c.py <infile> [<outfile>]

if outfile is omitted, output file is '<infile>.c'
""")
        return

    fname_in = sys.argv[1]
    if len(sys.argv) == 2:
        fname_base = os.path.splitext(fname_in)[0]
        fname_out = str(fname_base) + '.c'
    else:
        fname_out = sys.argv[2]

    with open(fname_in, "r") as python_file:
        code = python_file.read()
    name = "gauss"
    code = (code
            .replace(name+'.n', 'GAUSS_N')
            .replace(name+'.z', 'GAUSS_Z')
            .replace(name+'.w', 'GAUSS_W')
            .replace('if __name__ == "__main__"', "def main()")
    )

    c_code = "".join(translate([(code, fname_in, 1)]))
    # assumes translate() emits the python main() as "double main()"
    c_code = c_code.replace("double main()", "int main(int argc, char *argv[])")

    # The prelude's polyval evaluates coef[0]*x**(N-1) + ... + coef[N-1] by
    # Horner's rule.  coef[0] seeds the accumulator, so the loop starts at
    # i = 1; starting at 0 added coef[0] a second time.  Reads stay within
    # coef[0..N-1].
    with open(fname_out, "w") as file_out:
        file_out.write("""
#include <stdio.h>
#include <stdbool.h>
#include <math.h>
#define constant const
double square(double x) { return x*x; }
double cube(double x) { return x*x*x; }
double polyval(constant double *coef, double x, int N)
{
    int i = 1;
    double ans = coef[0];

    while (i < N) {
        ans = ans * x + coef[i++];
    }

    return ans;
}

""")
        file_out.write(c_code)
1336
# Allow running this module directly as a command line translator (see main).
if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.