source: sasmodels/sasmodels/py2c.py @ d5014e4

Last change on this file since d5014e4 was d5014e4, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

Merge remote-tracking branch 'omer/master'

  • Property mode set to 100644
File size: 44.9 KB
Line 
1r"""
2py2c
3~~~~
4
5Convert simple numeric python code into C code.
6
7This code is intended to translate direct algorithms for scientific code
8(mostly if statements and for loops operating on double precision values)
9into C code. Unlike projects like numba, cython, pypy and nuitka, the
10:func:`translate` function returns the corresponding C which can then be
11compiled with tinycc or sent to the GPU using CUDA or OpenCL.
12
There is special handling for certain constructs, such as *for i in range* and
small integer powers.
15
**TODO: make a nice list of supported constructs**
17
18Imports are not supported, but they are at least ignored so that properly
19constructed code can be run via python or translated to C without change.
20
21Most other python constructs are **not** supported:
22* classes
23* builtin types (dict, set, list)
24* exceptions
25* with context
26* del
27* yield
28* async
29* list slicing
30* multiple return values
31* "is/is not", "in/not in" conditionals
32
33There is limited support for list and list comprehensions, so long as they
34can be represented by a fixed array whose size is known at compile time, and
35they are small enough to be stored on the stack.
36
Variable definitions in C
-------------------------
Defining variables within the translate function is a bit of guesswork,
using the following rules:
41*   By default, a variable is a 'double'.
42*   Variable in a for loop is an int.
*   A variable that is referenced with brackets is an array of doubles. The
    variable within the brackets is an integer. For example, in the
    reference 'var1[var2]', var1 is a double array, and var2 is an integer.
*   Assignment to an argument makes that argument an array, and the index
    in that assignment is 0.
    For example, the following python code::

        def func(arg1, arg2):
            arg2 = 17.

    is translated to the following C code::

        double func(double arg1, double arg2[])
        {
            arg2[0] = 17.0;
        }

62*   All functions are defined as double, even if there is no
63    return statement.
64
65Debugging
66---------
67
68*print* is partially supported using a simple regular expression. This
69requires a stylized form. Be sure to use print as a function instead of
the print statement. If you are including substitution variables, use the
71% string substitution style. Include parentheses around the substitution
72tuple, even if there is only one item; do not include the final comma even
73if it is a single item (yes, it won't be a tuple, but it makes the regexp
74much simpler). Keep the item on a single line. Here are three forms that work::
75
76    print("x") => printf("x\n");
77    print("x %g"%(a)) => printf("x %g\n", a);
78    print("x %g %g %g"%(a, b, c)) => printf("x %g %g %g\n", a, b, c);
79
80You can generate *main* using the *if __name__ == "__main__":* construct.
81This does a simple substitution with "def main():" before translation and
82a substitution with "int main(int argc, double *argv[])" after translation.
83The result is that the content of the *if* block becomes the content of *main*.
84Along with the print statement, you can run and test a translation standalone
85using::
86
87    python py2c.py source.py
88    cc source.c
89    ./a.out
90
91Known issues
92------------
93The following constructs may cause problems:
94
95* implicit arrays: possible namespace collision for variable "vec#"
96* swap fails: "x,y = y,x" will set x==y
97* top-level statements: code outside a function body causes errors
98* line number skew: each statement should be tagged with its own #line
99  to avoid skew as comments are skipped and loop bodies are wrapped with
100  braces, etc.
101
102References
103----------
104
105Based on a variant of codegen.py:
106
107    https://github.com/andreif/codegen
108    :copyright: Copyright 2008 by Armin Ronacher.
109    :license: BSD.
110"""
111
112# Update Notes
113# ============
114# 2017-11-22, OE: Each 'visit_*' method is to build a C statement string. It
#                 should insert 4 blanks per indentation level. The 'body'
116#                 method will combine all the strings, by adding the
117#                 'current_statement' to the c_proc string list
118# 2017-11-22, OE: variables, argument definition implemented.  Note: An
119#                 argument is considered an array if it is the target of an
120#                 assignment. In that case it is translated to <var>[0]
# 2017-11-27, OE: 'pow' basically working
122# 2017-12-07, OE: Multiple assignment: a1,a2,...,an=b1,b2,...bn implemented
123# 2017-12-07, OE: Power function, including special cases of
124#                 square(x)(pow(x,2)) and cube(x)(pow(x,3)), implemented in
125#                 translate_power, called from visit_BinOp
# 2017-12-07, OE: Translation of integer division, '//' in python, implemented
127#                 in translate_integer_divide, called from visit_BinOp
128# 2017-12-07, OE: C variable definition handled in 'define_c_vars'
129#               : Python integer division, '//', translated to C in
130#                 'translate_integer_divide'
131# 2017-12-15, OE: Precedence maintained by writing opening and closing
#                 parentheses '(',')', in procedure 'visit_BinOp'.
133# 2017-12-18, OE: Added call to 'add_current_line()' at the beginning
134#                 of visit_Return
135# 2018-01-03, PK: Update interface for use in sasmodels
136# 2018-01-03, PK: support "expr if cond else expr" syntax
137# 2018-01-03, PK: x//y => (int)((x)/(y)) and x/y => ((double)(x)/(double)(y))
138# 2018-01-03, PK: True/False => true/false
139# 2018-01-03, PK: f(x) was introducing an extra semicolon
140# 2018-01-03, PK: simplistic print function, for debugging
141# 2018-01-03, PK: while expr: ... => while (expr) { ... }
142# 2018-01-04, OE: Fixed bug in 'visit_If': visiting node.orelse in case else exists.
143
144from __future__ import print_function
145
146import sys
147import ast
148from ast import NodeVisitor
149
# Operator tables: Python AST operator node type -> operator text.
# Entries such as '**', '//', 'is', 'in' and 'not' are the Python spellings,
# not C; per the update notes above, power and floor division are
# special-cased by visit_BinOp.

BINOP_SYMBOLS = {
    ast.Add: '+',
    ast.Sub: '-',
    ast.Mult: '*',
    ast.Div: '/',
    ast.Mod: '%',
    ast.Pow: '**',
    ast.LShift: '<<',
    ast.RShift: '>>',
    ast.BitOr: '|',
    ast.BitXor: '^',
    ast.BitAnd: '&',
    ast.FloorDiv: '//',
}

BOOLOP_SYMBOLS = {
    ast.And: '&&',
    ast.Or: '||',
}

CMPOP_SYMBOLS = {
    ast.Eq: '==',
    ast.NotEq: '!=',
    ast.Lt: '<',
    ast.LtE: '<=',
    ast.Gt: '>',
    ast.GtE: '>=',
    ast.Is: 'is',
    ast.IsNot: 'is not',
    ast.In: 'in',
    ast.NotIn: 'not in',
}

UNARYOP_SYMBOLS = {
    ast.Invert: '~',
    ast.Not: 'not',
    ast.UAdd: '+',
    ast.USub: '-',
}
185
186
# TODO: should not allow eval of arbitrary python
def isevaluable(s):
    """Return True if the text *s* evaluates as a Python expression.

    SECURITY: this calls :func:`eval` on *s*, so it must only ever be given
    trusted text (the source being translated), never external input.
    """
    try:
        eval(s)
    except Exception:
        return False
    else:
        return True
194
195class SourceGenerator(NodeVisitor):
    """This visitor transforms a well-formed Python syntax tree into C
    source code.  For more details have a look at the module docstring
    and the :func:`translate` function.
    """
200
201    def __init__(self, indent_with="    ", add_line_information=False,
202                 constants=None, fname=None, lineno=0):
203        self.result = []
204        self.indent_with = indent_with
205        self.add_line_information = add_line_information
206        self.indentation = 0
207        self.new_lines = 0
208
209        # for C
210        self.c_proc = []
211        self.signature_line = 0
212        self.arguments = []
213        self.current_function = ""
214        self.fname = fname
215        self.lineno_offset = lineno
216        self.warnings = []
217        self.statements = []
218        self.current_statement = ""
219        # TODO: use set rather than list for c_vars, ...
220        self.c_vars = []
221        self.c_int_vars = []
222        self.c_pointers = []
223        self.c_dcl_pointers = []
224        self.c_functions = []
225        self.c_vectors = []
226        self.c_constants = constants if constants is not None else {}
227        self.in_expr = False
228        self.in_subref = False
229        self.in_subscript = False
230        self.tuples = []
231        self.required_functions = []
232        self.is_sequence = False
233        self.visited_args = False
234
235    def write_python(self, x):
236        if self.new_lines:
237            if self.result:
238                self.result.append('\n' * self.new_lines)
239            self.result.append(self.indent_with * self.indentation)
240            self.new_lines = 0
241        self.result.append(x)
242
243    def write_c(self, statement):
244        # TODO: build up as a list rather than adding to string
245        self.current_statement += statement
246
247    def add_c_line(self, line):
248        indentation = self.indent_with * self.indentation
249        self.c_proc.append("".join((indentation, line, "\n")))
250
251    def add_current_line(self):
252        if self.current_statement:
253            self.add_c_line(self.current_statement)
254            self.current_statement = ''
255
256    def add_unique_var(self, new_var):
257        if new_var not in self.c_vars:
258            self.c_vars.append(str(new_var))
259
260    def write_sincos(self, node):
261        angle = str(node.args[0].id)
262        self.write_c(node.args[1].id + " = sin(" + angle + ");")
263        self.add_current_line()
264        self.write_c(node.args[2].id + " = cos(" + angle + ");")
265        self.add_current_line()
266        for arg in node.args:
267            self.add_unique_var(arg.id)
268
269    def newline(self, node=None, extra=0):
270        self.new_lines = max(self.new_lines, 1 + extra)
271        if node is not None and self.add_line_information:
272            self.write_c('// line: %s' % node.lineno)
273            self.new_lines = 1
274        if self.current_statement:
275            self.statements.append(self.current_statement)
276            self.current_statement = ''
277
278    def body(self, statements):
279        if self.current_statement:
280            self.add_current_line()
281        self.new_line = True
282        self.indentation += 1
283        for stmt in statements:
284            #if hasattr(stmt, 'targets') and hasattr(stmt.targets[0], 'id'):
285            #    target_name = stmt.targets[0].id # target name needed for debug only
286            self.visit(stmt)
287        self.add_current_line() # just for breaking point. to be deleted.
288        self.indentation -= 1
289
290    def body_or_else(self, node):
291        self.body(node.body)
292        if node.orelse:
293            self.unsupported(node, "for...else/while...else not supported")
294
295            self.newline()
296            self.write_c('else:')
297            self.body(node.orelse)
298
299    def signature(self, node):
300        want_comma = []
301        def write_comma():
302            if want_comma:
303                self.write_c(', ')
304            else:
305                want_comma.append(True)
306
307        # for C
308        for arg in node.args:
309            # CRUFT: 2.7 uses arg.id, 3.x uses arg.arg
310            try:
311                arg_name = arg.arg
312            except AttributeError:
313                arg_name = arg.id
314            self.arguments.append(arg_name)
315
316        padding = [None] *(len(node.args) - len(node.defaults))
317        for arg, default in zip(node.args, padding + node.defaults):
318            if default is not None:
319                # CRUFT: 2.7 uses arg.id, 3.x uses arg.arg
320                try:
321                    arg_name = arg.arg
322                except AttributeError:
323                    arg_name = arg.id
324                w_str = ("C does not support default parameters: %s=%s"
325                         % (arg_name, str(default.n)))
326                self.warnings.append(w_str)
327
328    def decorators(self, node):
329        for decorator in node.decorator_list:
330            self.newline(decorator)
331            self.write_python('@')
332            self.visit(decorator)
333
334    # Statements
335
336    def visit_Assert(self, node):
337        self.newline(node)
338        self.write_c('assert ')
339        self.visit(node.test)
340        if node.msg is not None:
341            self.write_python(', ')
342            self.visit(node.msg)
343
344    def define_c_vars(self, target):
345        if hasattr(target, 'id'):
346        # a variable is considered an array if it apears in the agrument list
347        # and being assigned to. For example, the variable p in the following
348        # sniplet is a pointer, while q is not
349        # def somefunc(p, q):
350        #  p = q + 1
351        #  return
352        #
353            if target.id not in self.c_vars:
354                if target.id in self.arguments:
355                    idx = self.arguments.index(target.id)
356                    new_target = self.arguments[idx] + "[0]"
357                    if new_target not in self.c_pointers:
358                        target.id = new_target
359                        self.c_pointers.append(self.arguments[idx])
360                else:
361                    self.c_vars.append(target.id)
362
363    def add_semi_colon(self):
364        #semi_pos = self.current_statement.find(';')
365        #if semi_pos >= 0:
366        #    self.current_statement = self.current_statement.replace(';', '')
367        self.write_c(';')
368
369    def visit_Assign(self, node):
370        self.add_current_line()
371        self.in_expr = True
372        for idx, target in enumerate(node.targets): # multi assign, as in 'a = b = c = 7'
373            if idx:
374                self.write_c(' = ')
375            self.define_c_vars(target)
376            self.visit(target)
377        # Capture assigned tuple names, if any
378        targets = self.tuples[:]
379        del self.tuples[:]
380        self.write_c(' = ')
381        self.is_sequence = False
382        self.visited_args = False
383        self.visit(node.value)
384        self.add_semi_colon()
385        self.add_current_line()
386        # Assign tuples to tuples, if any
387        # TODO: doesn't handle swap:  a,b = b,a
388        for target, item in zip(targets, self.tuples):
389            self.visit(target)
390            self.write_c(' = ')
391            self.visit(item)
392            self.add_semi_colon()
393            self.add_current_line()
394        if self.is_sequence and not self.visited_args:
395            for target in node.targets:
396                if hasattr(target, 'id'):
397                    if target.id in self.c_vars and target.id not in self.c_dcl_pointers:
398                        if target.id not in self.c_dcl_pointers:
399                            self.c_dcl_pointers.append(target.id)
400                            if target.id in self.c_vars:
401                                self.c_vars.remove(target.id)
402        self.current_statement = ''
403        self.in_expr = False
404
405    def visit_AugAssign(self, node):
406        if node.target.id not in self.c_vars:
407            if node.target.id not in self.arguments:
408                self.c_vars.append(node.target.id)
409        self.in_expr = True
410        self.visit(node.target)
411        self.write_c(' ' + BINOP_SYMBOLS[type(node.op)] + '= ')
412        self.visit(node.value)
413        self.add_semi_colon()
414        self.in_expr = False
415        self.add_current_line()
416
417    def visit_ImportFrom(self, node):
418        return  # import ignored
419        self.newline(node)
420        self.write_python('from %s%s import ' %('.' * node.level, node.module))
421        for idx, item in enumerate(node.names):
422            if idx:
423                self.write_python(', ')
424            self.write_python(item)
425
426    def visit_Import(self, node):
427        return  # import ignored
428        self.newline(node)
429        for item in node.names:
430            self.write_python('import ')
431            self.visit(item)
432
433    def visit_Expr(self, node):
434        #self.in_expr = True
435        self.newline(node)
436        self.generic_visit(node)
437        #self.in_expr = False
438
439    def write_c_pointers(self, start_var):
440        if self.c_dcl_pointers:
441            var_list = []
442            for c_ptr in self.c_dcl_pointers:
443                if c_ptr not in self.arguments:
444                    var_list.append("*" + c_ptr)
445                if c_ptr in self.c_vars:
446                    self.c_vars.remove(c_ptr)
447            if var_list:
448                c_dcl = "    double " + ", ".join(var_list) + ";\n"
449                self.c_proc.insert(start_var, c_dcl)
450                start_var += 1
451        return start_var
452
453    def insert_c_vars(self, start_var):
454        have_decls = False
455        start_var = self.write_c_pointers(start_var)
456        if self.c_int_vars:
457            for var in self.c_int_vars:
458                if var in self.c_vars:
459                    self.c_vars.remove(var)
460            decls = ", ".join(self.c_int_vars)
461            self.c_proc.insert(start_var, "    int " + decls + ";\n")
462            have_decls = True
463            start_var += 1
464
465        if self.c_vars:
466            decls = ", ".join(self.c_vars)
467            self.c_proc.insert(start_var, "    double " + decls + ";\n")
468            have_decls = True
469            start_var += 1
470
471        if self.c_vectors:
472            for vec_number, vec_value  in enumerate(self.c_vectors):
473                name = "vec" + str(vec_number + 1)
474                decl = "    double " + name + "[] = {" + vec_value + "};"
475                self.c_proc.insert(start_var, decl + "\n")
476                start_var += 1
477
478        del self.c_vars[:]
479        del self.c_int_vars[:]
480        del self.c_vectors[:]
481        del self.c_pointers[:]
482        del self.c_dcl_pointers[:]
483        if have_decls:
484            self.c_proc.insert(start_var, "\n")
485
486    def insert_signature(self):
487        arg_decls = []
488        for arg in self.arguments:
489            decl = "double " + arg
490            if arg in self.c_pointers:
491                decl += "[]"
492            arg_decls.append(decl)
493        args_str = ", ".join(arg_decls)
494        method_sig = 'double ' + self.current_function + '(' + args_str + ")"
495        if self.signature_line >= 0:
496            self.c_proc.insert(self.signature_line, method_sig)
497
498    def visit_FunctionDef(self, node):
499        if self.current_function:
500            self.unsupported(node, "function within a function")
501        self.current_function = node.name
502
503        # remember the location of the next warning that will be inserted
504        # so that we can stuff the function name ahead of the warning list
505        # if any warnings are generated by the function.
506        warning_index = len(self.warnings)
507
508        self.newline(extra=1)
509        self.decorators(node)
510        self.newline(node)
511        self.arguments = []
512        self.visit(node.args)
513        # for C
514        self.signature_line = len(self.c_proc)
515        self.add_c_line("\n{")
516        start_vars = len(self.c_proc) + 1
517        self.body(node.body)
518        self.add_c_line("}\n")
519        self.insert_signature()
520        self.insert_c_vars(start_vars)
521        del self.c_pointers[:]
522        self.current_function = ""
523        if warning_index != len(self.warnings):
524            self.warnings.insert(warning_index, "Warning in function '" + node.name + "':")
525
526    def visit_ClassDef(self, node):
527        have_args = []
528        def paren_or_comma():
529            if have_args:
530                self.write_python(', ')
531            else:
532                have_args.append(True)
533                self.write_python('(')
534
535        self.newline(extra=2)
536        self.decorators(node)
537        self.newline(node)
538        self.write_python('class %s' % node.name)
539        for base in node.bases:
540            paren_or_comma()
541            self.visit(base)
542        # CRUFT: python 2.6 does not have "keywords" attribute
543        if hasattr(node, 'keywords'):
544            for keyword in node.keywords:
545                paren_or_comma()
546                self.write_python(keyword.arg + '=')
547                self.visit(keyword.value)
548            if node.starargs is not None:
549                paren_or_comma()
550                self.write_python('*')
551                self.visit(node.starargs)
552            if node.kwargs is not None:
553                paren_or_comma()
554                self.write_python('**')
555                self.visit(node.kwargs)
556        self.write_python(have_args and '):' or ':')
557        self.body(node.body)
558
559    def visit_If(self, node):
560
561        self.write_c('if ')
562        self.in_expr = True
563        self.visit(node.test)
564        self.in_expr = False
565        self.write_c(' {')
566        self.body(node.body)
567        self.add_c_line('}')
568        while True:
569            else_ = node.orelse
570            if len(else_) == 0:
571                break
572            #elif hasattr(else_, 'orelse'):
573            elif len(else_) == 1 and isinstance(else_[0], ast.If):
574                node = else_[0]
575                #self.newline()
576                self.write_c('else if ')
577                self.in_expr = True
578                self.visit(node.test)
579                self.in_expr = False
580                self.write_c(' {')
581                self.body(node.body)
582                self.add_current_line()
583                self.add_c_line('}')
584                #break
585            else:
586                self.newline()
587                self.write_c('else {')
588                self.body(else_)
589                self.add_c_line('}')
590                break
591
592    def get_for_range(self, node):
593        stop = ""
594        start = '0'
595        step = '1'
596        for_args = []
597        temp_statement = self.current_statement
598        self.current_statement = ''
599        for arg in node.iter.args:
600            self.visit(arg)
601            for_args.append(self.current_statement)
602            self.current_statement = ''
603        self.current_statement = temp_statement
604        if len(for_args) == 1:
605            stop = for_args[0]
606        elif len(for_args) == 2:
607            start = for_args[0]
608            stop = for_args[1]
609        elif len(for_args) == 3:
610            start = for_args[0]
611            stop = for_args[1]
612            start = for_args[2]
613        else:
614            raise("Ilegal for loop parameters")
615        return start, stop, step
616
617    def add_c_int_var(self, name):
618        if name not in self.c_int_vars:
619            self.c_int_vars.append(name)
620
621    def visit_For(self, node):
622        # node: for iterator is stored in node.target.
623        # Iterator name is in node.target.id.
624        self.add_current_line()
625        fForDone = False
626        self.current_statement = ''
627        if hasattr(node.iter, 'func'):
628            if hasattr(node.iter.func, 'id'):
629                if node.iter.func.id == 'range':
630                    self.visit(node.target)
631                    iterator = self.current_statement
632                    self.current_statement = ''
633                    self.add_c_int_var(iterator)
634                    start, stop, step = self.get_for_range(node)
635                    self.write_c("for (" + iterator + "=" + str(start) +
636                                 " ; " + iterator + " < " + str(stop) +
637                                 " ; " + iterator + " += " + str(step) + ") {")
638                    self.body_or_else(node)
639                    self.write_c("}")
640                    fForDone = True
641        if not fForDone:
642            # Generate the statement that is causing the error
643            self.current_statement = ''
644            self.write_c('for ')
645            self.visit(node.target)
646            self.write_c(' in ')
647            self.visit(node.iter)
648            self.write_c(':')
649            # report the error
650            self.unsupported(node, "unsupported " + self.current_statement)
651
652    def visit_While(self, node):
653        self.newline(node)
654        self.write_c('while ')
655        self.visit(node.test)
656        self.write_c(' {')
657        self.body_or_else(node)
658        self.write_c('}')
659        self.add_current_line()
660
661    def visit_With(self, node):
662        self.unsupported(node)
663
664        self.newline(node)
665        self.write_python('with ')
666        self.visit(node.context_expr)
667        if node.optional_vars is not None:
668            self.write_python(' as ')
669            self.visit(node.optional_vars)
670        self.write_python(':')
671        self.body(node.body)
672
673    def visit_Pass(self, node):
674        self.newline(node)
675        #self.write_python('pass')
676
677    def visit_Print(self, node):
678        self.unsupported(node)
679
680        # CRUFT: python 2.6 only
681        self.newline(node)
682        self.write_c('print ')
683        want_comma = False
684        if node.dest is not None:
685            self.write_c(' >> ')
686            self.visit(node.dest)
687            want_comma = True
688        for value in node.values:
689            if want_comma:
690                self.write_c(', ')
691            self.visit(value)
692            want_comma = True
693        if not node.nl:
694            self.write_c(',')
695
696    def visit_Delete(self, node):
697        self.unsupported(node)
698
699        self.newline(node)
700        self.write_python('del ')
701        for idx, target in enumerate(node):
702            if idx:
703                self.write_python(', ')
704            self.visit(target)
705
706    def visit_TryExcept(self, node):
707        self.unsupported(node)
708
709        self.newline(node)
710        self.write_python('try:')
711        self.body(node.body)
712        for handler in node.handlers:
713            self.visit(handler)
714
715    def visit_TryFinally(self, node):
716        self.unsupported(node)
717
718        self.newline(node)
719        self.write_python('try:')
720        self.body(node.body)
721        self.newline(node)
722        self.write_python('finally:')
723        self.body(node.finalbody)
724
725    def visit_Global(self, node):
726        self.unsupported(node)
727
728        self.newline(node)
729        self.write_python('global ' + ', '.join(node.names))
730
731    def visit_Nonlocal(self, node):
732        self.newline(node)
733        self.write_python('nonlocal ' + ', '.join(node.names))
734
735    def visit_Return(self, node):
736        self.add_current_line()
737        self.in_expr = True
738        if node.value is None:
739            self.write_c('return')
740        else:
741            self.write_c('return ')
742            self.visit(node.value)
743        self.add_semi_colon()
744        self.in_expr = False
745        self.add_c_line(self.current_statement)
746        self.current_statement = ''
747
748    def visit_Break(self, node):
749        self.newline(node)
750        self.write_c('break')
751
752    def visit_Continue(self, node):
753        self.newline(node)
754        self.write_c('continue')
755
756    def visit_Raise(self, node):
757        self.unsupported(node)
758
759        # CRUFT: Python 2.6 / 3.0 compatibility
760        self.newline(node)
761        self.write_python('raise')
762        if hasattr(node, 'exc') and node.exc is not None:
763            self.write_python(' ')
764            self.visit(node.exc)
765            if node.cause is not None:
766                self.write_python(' from ')
767                self.visit(node.cause)
768        elif hasattr(node, 'type') and node.type is not None:
769            self.visit(node.type)
770            if node.inst is not None:
771                self.write_python(', ')
772                self.visit(node.inst)
773            if node.tback is not None:
774                self.write_python(', ')
775                self.visit(node.tback)
776
777    # Expressions
778
779    def visit_Attribute(self, node):
780        self.unsupported(node, "attribute reference a.b not supported")
781
782        self.visit(node.value)
783        self.write_python('.' + node.attr)
784
785    def visit_Call(self, node):
786        want_comma = []
787        def write_comma():
788            if want_comma:
789                self.write_c(', ')
790            else:
791                want_comma.append(True)
792        if hasattr(node.func, 'id'):
793            if node.func.id not in self.c_functions:
794                self.c_functions.append(node.func.id)
795            if node.func.id == 'abs':
796                self.write_c("fabs ")
797            elif node.func.id == 'int':
798                self.write_c('(int) ')
799            elif node.func.id == "SINCOS":
800                self.write_sincos(node)
801                return
802            else:
803                self.visit(node.func)
804        else:
805            self.visit(node.func)
806        self.write_c('(')
807        for arg in node.args:
808            write_comma()
809            self.visited_args = True
810            self.visit(arg)
811        for keyword in node.keywords:
812            write_comma()
813            self.write_c(keyword.arg + '=')
814            self.visit(keyword.value)
815        if hasattr(node, 'starargs'):
816            if node.starargs is not None:
817                write_comma()
818                self.write_c('*')
819                self.visit(node.starargs)
820        if hasattr(node, 'kwargs'):
821            if node.kwargs is not None:
822                write_comma()
823                self.write_c('**')
824                self.visit(node.kwargs)
825        self.write_c(')')
826        if not self.in_expr:
827            self.add_semi_colon()
828
829    TRANSLATE_CONSTANTS = {
830        # python 2 uses normal name references through vist_Name
831        'True': 'true',
832        'False': 'false',
833        'None': 'NULL',  # "None" will probably fail for other reasons
834        # python 3 uses NameConstant
835        True: 'true',
836        False: 'false',
837        None: 'NULL',  # "None" will probably fail for other reasons
838        }
839
840    def visit_Name(self, node):
841        translation = self.TRANSLATE_CONSTANTS.get(node.id, None)
842        if translation:
843            self.write_c(translation)
844            return
845        self.write_c(node.id)
846        if node.id in self.c_pointers and not self.in_subref:
847            self.write_c("[0]")
848        name = ""
849        sub = node.id.find("[")
850        if sub > 0:
851            name = node.id[0:sub].strip()
852        else:
853            name = node.id
854        # add variable to c_vars if it ins't there yet, not an argument and not a number
855        if (name not in self.c_functions and name not in self.c_vars and
856                name not in self.c_int_vars and name not in self.arguments and
857                name not in self.c_constants and not name.isdigit()):
858            if self.in_subscript:
859                self.add_c_int_var(node.id)
860            else:
861                self.c_vars.append(node.id)
862
863    def visit_NameConstant(self, node):
864        translation = self.TRANSLATE_CONSTANTS.get(node.value, None)
865        if translation is not None:
866            self.write_c(translation)
867        else:
868            self.unsupported(node, "don't know how to translate %r"%node.value)
869
870    def visit_Str(self, node):
871        s = node.s
872        s = s.replace('\\','\\\\').replace('"','\\"').replace('\n','\\n')
873        self.write_c('"')
874        self.write_c(s)
875        self.write_c('"')
876
877    def visit_Bytes(self, node):
878        s = node.s
879        s = s.replace('\\','\\\\').replace('"','\\"').replace('\n','\\n')
880        self.write_c('"')
881        self.write_c(s)
882        self.write_c('"')
883
884    def visit_Num(self, node):
885        self.write_c(repr(node.n))
886
887    def visit_Tuple(self, node):
888        for idx, item in enumerate(node.elts):
889            if idx:
890                self.tuples.append(item)
891            else:
892                self.visit(item)
893
894    def sequence_visit(left, right):
895        def visit(self, node):
896            self.is_sequence = True
897            s = ""
898            for idx, item in enumerate(node.elts):
899                if idx > 0 and s:
900                    s += ', '
901                if hasattr(item, 'id'):
902                    s += item.id
903                elif hasattr(item, 'n'):
904                    s += str(item.n)
905            if s:
906                self.c_vectors.append(s)
907                vec_name = "vec"  + str(len(self.c_vectors))
908                self.write_c(vec_name)
909        return visit
910
911    visit_List = sequence_visit('[', ']')
912    visit_Set = sequence_visit('{', '}')
913    del sequence_visit
914
915    def visit_Dict(self, node):
916        self.unsupported(node)
917
918        self.write_python('{')
919        for idx, (key, value) in enumerate(zip(node.keys, node.values)):
920            if idx:
921                self.write_python(', ')
922            self.visit(key)
923            self.write_python(': ')
924            self.visit(value)
925        self.write_python('}')
926
927    def get_special_power(self, string):
928        function_name = ''
929        is_negative_exp = False
930        if isevaluable(str(self.current_statement)):
931            exponent = eval(string)
932            is_negative_exp = exponent < 0
933            abs_exponent = abs(exponent)
934            if abs_exponent == 2:
935                function_name = "square"
936            elif abs_exponent == 3:
937                function_name = "cube"
938            elif abs_exponent == 0.5:
939                function_name = "sqrt"
940            elif abs_exponent == 1.0/3.0:
941                function_name = "cbrt"
942        if function_name == '':
943            function_name = "pow"
944        return function_name, is_negative_exp
945
946    def translate_power(self, node):
947        # get exponent by visiting the right hand argument.
948        function_name = "pow"
949        temp_statement = self.current_statement
950        # 'visit' functions write the results to the 'current_statement' class memnber
951        # Here, a temporary variable, 'temp_statement', is used, that enables the
952        # use of the 'visit' function
953        self.current_statement = ''
954        self.visit(node.right)
955        exponent = self.current_statement.replace(' ', '')
956        function_name, is_negative_exp = self.get_special_power(self.current_statement)
957        self.current_statement = temp_statement
958        if is_negative_exp:
959            self.write_c("1.0 /(")
960        self.write_c(function_name + "(")
961        self.visit(node.left)
962        if function_name == "pow":
963            self.write_c(", ")
964            self.visit(node.right)
965        self.write_c(")")
966        if is_negative_exp:
967            self.write_c(")")
968        #self.write_c(" ")
969
970    def translate_integer_divide(self, node):
971        self.write_c("(int)((")
972        self.visit(node.left)
973        self.write_c(")/(")
974        self.visit(node.right)
975        self.write_c("))")
976
977    def translate_float_divide(self, node):
978        self.write_c("((double)(")
979        self.visit(node.left)
980        self.write_c(")/(double)(")
981        self.visit(node.right)
982        self.write_c("))")
983
984    def visit_BinOp(self, node):
985        self.write_c("(")
986        if '%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.Pow]:
987            self.translate_power(node)
988        elif '%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.FloorDiv]:
989            self.translate_integer_divide(node)
990        elif '%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.Div]:
991            self.translate_float_divide(node)
992        else:
993            self.visit(node.left)
994            self.write_c(' %s ' % BINOP_SYMBOLS[type(node.op)])
995            self.visit(node.right)
996        self.write_c(")")
997
998    # for C
999    def visit_BoolOp(self, node):
1000        self.write_c('(')
1001        for idx, value in enumerate(node.values):
1002            if idx:
1003                self.write_c(' %s ' % BOOLOP_SYMBOLS[type(node.op)])
1004            self.visit(value)
1005        self.write_c(')')
1006
1007    def visit_Compare(self, node):
1008        self.write_c('(')
1009        self.visit(node.left)
1010        for op, right in zip(node.ops, node.comparators):
1011            self.write_c(' %s ' % CMPOP_SYMBOLS[type(op)])
1012            self.visit(right)
1013        self.write_c(')')
1014
1015    def visit_UnaryOp(self, node):
1016        self.write_c('(')
1017        op = UNARYOP_SYMBOLS[type(node.op)]
1018        self.write_c(op)
1019        if op == 'not':
1020            self.write_c(' ')
1021        self.visit(node.operand)
1022        self.write_c(')')
1023
1024    def visit_Subscript(self, node):
1025        if node.value.id not in self.c_constants:
1026            if node.value.id not in self.c_pointers:
1027                self.c_pointers.append(node.value.id)
1028        self.in_subref = True
1029        self.visit(node.value)
1030        self.in_subref = False
1031        self.write_c('[')
1032        self.in_subscript = True
1033        self.visit(node.slice)
1034        self.in_subscript = False
1035        self.write_c(']')
1036
1037    def visit_Slice(self, node):
1038        if node.lower is not None:
1039            self.visit(node.lower)
1040        self.write_python(':')
1041        if node.upper is not None:
1042            self.visit(node.upper)
1043        if node.step is not None:
1044            self.write_python(':')
1045            if not(isinstance(node.step, Name) and node.step.id == 'None'):
1046                self.visit(node.step)
1047
1048    def visit_ExtSlice(self, node):
1049        for idx, item in node.dims:
1050            if idx:
1051                self.write_python(', ')
1052            self.visit(item)
1053
1054    def visit_Yield(self, node):
1055        self.unsupported(node)
1056
1057        self.write_python('yield ')
1058        self.visit(node.value)
1059
1060    def visit_Lambda(self, node):
1061        self.unsupported(node)
1062
1063        self.write_python('lambda ')
1064        self.visit(node.args)
1065        self.write_python(': ')
1066        self.visit(node.body)
1067
1068    def visit_Ellipsis(self, node):
1069        self.unsupported(node)
1070
1071        self.write_python('Ellipsis')
1072
1073    def generator_visit(left, right):
1074        def visit(self, node):
1075            self.write_python(left)
1076            self.write_c(left)
1077            self.visit(node.elt)
1078            for comprehension in node.generators:
1079                self.visit(comprehension)
1080            self.write_c(right)
1081            #self.write_python(right)
1082        return visit
1083
1084    visit_ListComp = generator_visit('[', ']')
1085    visit_GeneratorExp = generator_visit('(', ')')
1086    visit_SetComp = generator_visit('{', '}')
1087    del generator_visit
1088
1089    def visit_DictComp(self, node):
1090        self.unsupported(node)
1091
1092        self.write_python('{')
1093        self.visit(node.key)
1094        self.write_python(': ')
1095        self.visit(node.value)
1096        for comprehension in node.generators:
1097            self.visit(comprehension)
1098        self.write_python('}')
1099
1100    def visit_IfExp(self, node):
1101        self.write_c('((')
1102        self.visit(node.test)
1103        self.write_c(')?(')
1104        self.visit(node.body)
1105        self.write_c('):(')
1106        self.visit(node.orelse)
1107        self.write_c('))')
1108
1109    def visit_Starred(self, node):
1110        self.write_c('*')
1111        self.visit(node.value)
1112
1113    def visit_Repr(self, node):
1114        # CRUFT: python 2.6 only
1115        self.write_c('`')
1116        self.visit(node.value)
1117        self.write_python('`')
1118
1119    # Helper Nodes
1120
1121    def visit_alias(self, node):
1122        self.unsupported(node)
1123
1124        self.write_python(node.name)
1125        if node.asname is not None:
1126            self.write_python(' as ' + node.asname)
1127
1128    def visit_comprehension(self, node):
1129        self.write_c(' for ')
1130        self.visit(node.target)
1131        self.write_C(' in ')
1132        #self.write_python(' in ')
1133        self.visit(node.iter)
1134        if node.ifs:
1135            for if_ in node.ifs:
1136                self.write_python(' if ')
1137                self.visit(if_)
1138
1139    def visit_arguments(self, node):
1140        self.signature(node)
1141
1142    def unsupported(self, node, message=None):
1143        if hasattr(node, "value"):
1144            lineno = node.value.lineno
1145        elif hasattr(node, "iter"):
1146            lineno = node.iter.lineno
1147        else:
1148            #print(dir(node))
1149            lineno = 0
1150
1151        lineno += self.lineno_offset
1152        if self.fname:
1153            location = "%s(%d)" % (self.fname, lineno)
1154        else:
1155            location = "%d" % (self.fname, lineno)
1156        if self.current_function:
1157            location += ", function %s" % self.current_function
1158        if message is None:
1159            message = node.__class__.__name__ + " syntax not supported"
1160        raise SyntaxError("[%s] %s" % (location, message))
1161
def print_function(f=None):
    """
    Print the C translation of function *f* produced by :func:`to_source`.

    Nothing is printed when *f* is None.
    """
    import ast
    import inspect
    import textwrap
    if f is not None:
        # getsource keeps the original indentation of methods and nested
        # functions, which ast.parse rejects; dedent to module level first.
        tree = ast.parse(textwrap.dedent(inspect.getsource(f)))
        tree_source = to_source(tree)
        print(tree_source)
1173
def define_constant(name, value, block_size=1):
    # type: (str, any, int) -> str
    """
    Convert a python constant into a C constant of the same name.

    Returns the C declaration of the constant as a string, possibly containing
    line feeds.  The string will not be indented.

    Supports integers, reals and sequences of reals:

    * Any :class:`numbers.Integral` becomes ``constant int``.  This includes
      numpy integer scalars, which are not subclasses of python int.
    * Any other :class:`numbers.Real` becomes ``constant double``.
    * Sequences become a ``constant double`` array, zero-padded to a
      multiple of *block_size* elements.

    Raises TypeError if *value* is neither a number nor a sequence.
    """
    import numbers

    const = "constant "  # OpenCL needs globals to be constant
    if isinstance(value, numbers.Integral):
        parts = [const + "int ", name, " = ", "%d"%value, ";\n"]
    elif isinstance(value, numbers.Real):
        parts = [const + "double ", name, " = ", "%.15g"%value, ";\n"]
    else:
        try:
            len(value)
        except TypeError:
            raise TypeError("constant %s must be int, float or [float, ...]"%name)
        # extend constant arrays to a multiple of block_size; not sure if this
        # is necessary, but some OpenCL targets broke if the number
        # of parameters in the parameter table was not a multiple of 4,
        # so do it for all constant arrays to be safe.
        if len(value)%block_size != 0:
            value = list(value) + [0.]*(block_size - len(value)%block_size)
        elements = ["%.15g"%v for v in value]
        parts = [const + "double ", name, "[]", " = ",
                 "{\n   ", ", ".join(elements), "\n};\n"]

    return "".join(parts)
1205
1206
# Modified from the following:
#
#    http://code.activestate.com/recipes/578272-topological-sort/
#    Copyright (C) 2012 Sam Denton
#    License: MIT
def ordered_dag(dag):
    # type: (Dict[T, Iterable[T]]) -> Iterator[T]
    """
    Given a dag defined by a dictionary of {k1: [k2, ...]} yield keys
    in order such that every key occurs after the keys it depends upon.

    Dependencies may be given as any iterable (list, set, tuple); the
    input dictionary is not modified.  An empty dag yields nothing.

    This is an iterator not a sequence.  To reverse it use::

        reversed(tuple(ordered_dag(dag)))

    Raise ValueError if there are any cycles.

    Keys are arbitrary hashable values.
    """
    # Work on private set copies: set operations below need sets, and the
    # documented {k: [..]} form uses lists.
    dag = {node: set(links) for node, links in dag.items()}

    # make leaves depend on the empty set; set().union(*...) also copes with
    # an empty dag, where reduce() would fail.
    leaves = set().union(*dag.values()) - set(dag)
    dag.update({node: set() for node in leaves})
    while True:
        leaves = set(node for node, links in dag.items() if not links)
        if not leaves:
            break
        for node in leaves:
            yield node
        dag = {node: (links-leaves)
               for node, links in dag.items() if node not in leaves}
    if dag:
        raise ValueError("Cyclic dependes exists amongst these items:\n%s"
                         % ", ".join(str(node) for node in dag.keys()))
1245
import re
# Rewrite simple python print() calls into C printf() calls before parsing.
# Each pattern consumes the trailing newline of the print statement, so each
# replacement must put it back, or the next source line is joined onto the
# printf.  In a replacement template "\\n" yields a literal backslash-n,
# i.e. the C escape inside the printf format string (a bare "\n" would put
# a raw newline inside the generated string literal).
PRINT_ARGS = re.compile(r'print[(]"(?P<template>[^"]*)" *% *[(](?P<args>[^\n]*)[)] *[)] *\n')
SUBST_ARGS = r'printf("\g<template>\\n", \g<args>)\n'
PRINT_STR = re.compile(r'print[(]"(?P<template>[^"]*)" *[)] *\n')
SUBST_STR = r'printf("\g<template>\\n")\n'
def translate(functions, constants=None):
    # type: (Sequence[Tuple[str, str, int]], Dict[str, Any]) -> Tuple[List[str], List[str]]
    """
    Convert a list of functions to C code.

    Returns a tuple ``(snippets, warnings)``.  *snippets* is a list of C code
    strings: each translated function is preceded by a ``#line`` directive
    pointing back at its python source.  *warnings* lists the warnings
    generated by the translator.

    A function is given by the tuple (source, filename, line number).

    Simple ``print("...")`` and ``print("..." % (args))`` statements are
    rewritten into ``printf`` calls before parsing.

    Global constants are given in a dictionary of {name: value}.  The
    constants are used for name space resolution and type inferencing.
    Constants are not translated by this code. Instead, call
    :func:`define_constant` with name and value, and maybe block_size
    if arrays need to be padded to the next block boundary.

    Function prototypes are not generated. Use :func:`ordered_dag`
    to list the functions in reverse order of dependency before calling
    translate.  (Maybe a future revision will return the function prototypes
    so that a suitable "*.h" file can be generated.)
    """
    snippets = []
    warnings = []
    for source, fname, lineno in functions:
        # Point C compiler diagnostics back at the python source.  The file
        # name must form a valid C string literal, so escape backslashes
        # and double quotes.
        c_fname = fname.replace('\\', '\\\\').replace('"', '\\"')
        snippets.append('#line %d "%s"\n' % (lineno, c_fname))
        # Replace simple print function calls with printf statements
        source = PRINT_ARGS.sub(SUBST_ARGS, source)
        source = PRINT_STR.sub(SUBST_STR, source)
        tree = ast.parse(source)
        generator = SourceGenerator(constants=constants, fname=fname, lineno=lineno)
        generator.visit(tree)
        snippets.append("".join(generator.c_proc))
        warnings.extend(generator.warnings)
    return snippets, warnings
1287
def to_source(tree, constants=None, fname=None, lineno=0):
    """
    Convert a python syntax tree into C source code.

    *constants*, *fname* and *lineno* are passed to the SourceGenerator
    exactly as :func:`translate` does.  The previous body referenced an
    undefined ``generator`` and always raised NameError.
    """
    generator = SourceGenerator(constants=constants, fname=fname, lineno=lineno)
    generator.visit(tree)
    c_code = "".join(generator.c_proc)
    return c_code
1294
1295
# Support code that main() writes ahead of the translated functions.
# polyval(coef, x, N) evaluates the polynomial whose N coefficients
# coef[0..N-1] are ordered from the highest power down, using Horner's rule.
# (The loop previously started at i = 0 and so applied coef[0] twice.)
C_HEADER = """
#include <stdio.h>
#include <stdbool.h>
#include <math.h>
#define constant const
double square(double x) { return x*x; }
double cube(double x) { return x*x*x; }
double polyval(constant double *coef, double x, int N)
{
    int i = 1;
    double ans = coef[0];

    while (i < N) {
        ans = ans * x + coef[i++];
    }

    return ans;
}
"""
1315
# Command-line help printed by main() when no input file is given.
USAGE = (
    "Usage: python py2c.py <infile> [<outfile>]\n"
    "\n"
    "if outfile is omitted, output file is '<infile>.c'\n"
)
1321
def main():
    """
    Command-line entry point: translate a python file into a C file.

    Reads ``sys.argv[1]`` and writes C_HEADER followed by the translation to
    ``sys.argv[2]``, or to ``<infile base>.c`` when no output is named.  With
    no arguments, prints USAGE.  Translator warnings are printed at the end.
    """
    import os
    argv = sys.argv
    if len(argv) == 1:
        print(USAGE)
        return

    source_path = argv[1]
    if len(argv) == 2:
        target_path = str(os.path.splitext(source_path)[0]) + '.c'
    else:
        target_path = argv[2]

    with open(source_path, "r") as handle:
        code = handle.read()

    # References to the gauss object's n/z/w attributes become the C names
    # GAUSS_N/GAUSS_Z/GAUSS_W, and the script's main guard is turned into a
    # function named main so that it is translated too.
    prefix = "gauss"
    substitutions = [
        (prefix + '.n', 'GAUSS_N'),
        (prefix + '.z', 'GAUSS_Z'),
        (prefix + '.w', 'GAUSS_W'),
        ('if __name__ == "__main__"', "def main()"),
    ]
    for old, new in substitutions:
        code = code.replace(old, new)

    snippets, messages = translate([(code, source_path, 1)])
    c_code = "".join(snippets).replace(
        "double main()", "int main(int argc, char *argv[])")

    with open(target_path, "w") as out:
        out.write(C_HEADER)
        out.write(c_code)

    if messages:
        print("\n".join(messages))
1357
# Allow running this module as a script: python py2c.py <infile> [<outfile>]
if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.