source: sasmodels/sasmodels/py2c.py @ 6f91c91

Last change on this file since 6f91c91 was 6f91c91, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

track line number for c header

  • Property mode set to 100644
File size: 44.8 KB
r"""
py2c
~~~~

Convert simple numeric python code into C code.

This code is intended to translate direct algorithms for scientific code
(mostly if statements and for loops operating on double precision values)
into C code. Unlike projects like numba, cython, pypy and nuitka, the
:func:`translate` function returns the corresponding C which can then be
compiled with tinycc or sent to the GPU using CUDA or OpenCL.

There is special handling for certain constructs, such as *for i in range*
and small integer powers.
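
The two special cases can be sketched as follows. This is a minimal
illustration, not py2c's implementation. The helper names *range_to_c_for*
and *small_power_to_c* are hypothetical, and the arguments are assumed to be
C expression strings that have already been translated. The *range* step is
assumed to be positive, because the generated loop test is always "<".

```python
def range_to_c_for(var, *args):
    # Map range(stop), range(start, stop) or range(start, stop, step) onto a
    # C for-loop header.  Arguments are C expression strings; a positive step
    # is assumed because the loop test is always "<".
    if len(args) == 1:
        start, stop, step = '0', args[0], '1'
    elif len(args) == 2:
        start, stop, step = args[0], args[1], '1'
    elif len(args) == 3:
        start, stop, step = args
    else:
        raise ValueError('range() takes 1 to 3 arguments')
    return 'for (%s = %s; %s < %s; %s += %s) {' % (var, start, var, stop, var, step)


def small_power_to_c(base, exponent):
    # Translate base**exponent, special-casing the integer exponents 2 and 3.
    if exponent == 2:
        return 'square(%s)' % base
    if exponent == 3:
        return 'cube(%s)' % base
    return 'pow(%s, %s)' % (base, exponent)


print(range_to_c_for('i', '1', 'n', '2'))  # for (i = 1; i < n; i += 2) {
print(small_power_to_c('x', 3))            # cube(x)
```

*square* and *cube* are the C helpers named in the update notes. Any other
exponent falls back to *pow*.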

**TODO: make a nice list of supported constructs**

Imports are not supported, but they are at least ignored so that properly
constructed code can be run via python or translated to C without change.

Most other python constructs are **not** supported:

* classes
* builtin types (dict, set, list)
* exceptions
* with context
* del
* yield
* async
* list slicing
* multiple return values
* "is/is not", "in/not in" conditionals
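
The translator only reports these constructs when it reaches them, so a
pre-flight scan can be useful. The sketch below is hypothetical:
*find_unsupported* is not part of py2c. It only covers the node types listed
above that the standard *ast* module exposes, and it skips any names the
running Python version does not define.

```python
import ast

# Node types matching the unsupported list above; names that the running
# Python version does not define (e.g. TryExcept on Python 3) are skipped.
_UNSUPPORTED_NAMES = (
    'ClassDef',                                          # classes
    'Dict', 'Set', 'DictComp', 'SetComp',                # builtin types
    'Try', 'TryExcept', 'TryFinally', 'Raise',           # exceptions
    'With', 'Delete', 'Yield', 'YieldFrom',              # with, del, yield
    'AsyncFunctionDef', 'AsyncFor', 'AsyncWith', 'Await',  # async
    'Slice',                                             # list slicing
)
UNSUPPORTED = tuple(getattr(ast, name) for name in _UNSUPPORTED_NAMES
                    if hasattr(ast, name))
UNSUPPORTED_CMP = (ast.Is, ast.IsNot, ast.In, ast.NotIn)


def find_unsupported(source):
    # Return (lineno, construct) pairs for constructs py2c cannot translate.
    problems = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, UNSUPPORTED):
            problems.append((getattr(node, 'lineno', 0), type(node).__name__))
        elif isinstance(node, ast.Compare):
            for op in node.ops:
                if isinstance(op, UNSUPPORTED_CMP):
                    problems.append((node.lineno, type(op).__name__))
        elif isinstance(node, ast.Return) and isinstance(node.value, ast.Tuple):
            problems.append((node.lineno, 'multiple return values'))
    return problems


source = 'def f(x):\n    if x is None:\n        return x, x\n    return x[1:]\n'
for lineno, what in find_unsupported(source):
    print(lineno, what)
```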

There is limited support for lists and list comprehensions, so long as they
can be represented by a fixed array whose size is known at compile time, and
they are small enough to be stored on the stack.
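
As an illustration of the fixed-size requirement, a hypothetical helper
*list_to_c_array* could turn a literal list of numbers into a stack array
declaration. It uses the "vec#" naming that py2c uses for implicit arrays
(see *Known issues*). The sketch assumes every element is a numeric literal
and stores all values as doubles.

```python
import ast


def list_to_c_array(expr, vec_number):
    # Turn a Python list literal of numbers into a C stack-array declaration.
    # The element count fixes the array size at compile time; the "vecN"
    # name follows the implicit-array naming described in Known issues.
    node = ast.parse(expr, mode='eval').body
    if not isinstance(node, ast.List):
        raise ValueError('expected a list literal')
    values = [repr(float(ast.literal_eval(item))) for item in node.elts]
    name = 'vec%d' % vec_number
    return name, 'double %s[] = {%s};' % (name, ', '.join(values))


name, decl = list_to_c_array('[1, 2.5, -3]', 1)
print(decl)  # double vec1[] = {1.0, 2.5, -3.0};
```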

Variable definitions in C
-------------------------
Declaring variables within the translate function involves some guesswork,
using the following rules:

*   By default, a variable is a 'double'.
*   The loop variable of a *for i in range* loop is an int.
*   A variable that is referenced with brackets is an array of doubles. The
    variable within the brackets is an integer. For example, in the
    reference 'var1[var2]', var1 is a double array, and var2 is an integer.
*   Assignment to an argument makes that argument an array, and the index
    in that assignment is 0. For example, the following python code::

        def func(arg1, arg2):
            arg2 = 17.

    is translated to the following C code::

        double func(double arg1, double arg2[])
        {
            arg2[0] = 17.0;
        }

*   All functions are defined as double, even if there is no
    return statement.
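
The rules above can be approximated with a small AST walk. This is a rough
sketch, not py2c's own logic, which tracks these sets incrementally while it
generates code. *classify_variables* is a hypothetical helper. It only
inspects the first function in the source, and it only treats plain
"name = value" assignments as definitions.

```python
import ast

INDEX_NODE = getattr(ast, 'Index', None)  # wraps subscripts before Python 3.9


def classify_variables(source):
    # Approximate the variable rules above for the first function in source:
    # loop targets and bracket indices are ints, bracketed names are double
    # arrays, assigned arguments become pointers, other assigned names are
    # doubles.  Only plain "name = value" assignments are inspected.
    func = ast.parse(source).body[0]
    args = [a.arg for a in func.args.args]
    ints, arrays, pointers, doubles = set(), set(), set(), set()
    for node in ast.walk(func):
        if isinstance(node, ast.For) and isinstance(node.target, ast.Name):
            ints.add(node.target.id)
        elif isinstance(node, ast.Subscript):
            if isinstance(node.value, ast.Name):
                arrays.add(node.value.id)
            index = node.slice
            if INDEX_NODE is not None and isinstance(index, INDEX_NODE):
                index = index.value
            if isinstance(index, ast.Name):
                ints.add(index.id)
        elif isinstance(node, ast.Assign):
            for target in node.targets:
                if isinstance(target, ast.Name):
                    if target.id in args:
                        pointers.add(target.id)
                    else:
                        doubles.add(target.id)
    doubles -= ints | arrays
    return {'int': ints, 'double[]': arrays, 'pointer': pointers,
            'double': doubles}


example = ('def f(a, b, n):\n    total = 0.\n    for i in range(n):\n'
           '        total += a[i]\n    b = total\n')
print(classify_variables(example))
```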

Debugging
---------

*print* is partially supported using a simple regular expression. This
requires a stylized form. Be sure to use print as a function instead of
the print statement. If you are including substitution variables, use the
% string substitution style. Include parentheses around the substitution
tuple, even if there is only one item; do not include the final comma even
if it is a single item (yes, it won't be a tuple, but it makes the regexp
much simpler). Keep the whole print call on a single line. Here are three
forms that work::

    print("x") => printf("x\n");
    print("x %g"%(a)) => printf("x %g\n", a);
    print("x %g %g %g"%(a, b, c)) => printf("x %g %g %g\n", a, b, c);
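
A hypothetical regular expression for the three stylized forms might look
like this. It is not the pattern py2c uses, and it assumes the format
string contains no escaped double quotes.

```python
import re

# Hypothetical pattern (not py2c's own): print("fmt") or print("fmt"%(args)),
# kept on one line, with no escaped quotes inside the format string.
PRINT_PATTERN = re.compile(
    r'^(?P<indent>\s*)print\("(?P<fmt>[^"]*)"(?:\s*%\s*\((?P<args>.*)\))?\)\s*$')


def print_to_printf(line):
    match = PRINT_PATTERN.match(line)
    if match is None:
        return line  # not one of the stylized forms; leave it untouched
    args = match.group('args')
    tail = ', ' + args if args else ''
    # '\\n' puts a literal backslash-n into the C format string
    return '%sprintf("%s\\n"%s);' % (match.group('indent'), match.group('fmt'), tail)


print(print_to_printf('print("x %g %g"%(a, b))'))  # printf("x %g %g\n", a, b);
```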

You can generate *main* using the *if __name__ == "__main__":* construct.
This does a simple substitution with "def main():" before translation and
a substitution with "int main(int argc, double *argv[])" after translation.
The result is that the content of the *if* block becomes the content of *main*.
Along with the print statement, you can run and test a translation standalone
using::

    python py2c.py source.py
    cc source.c
    ./a.out
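
The pre-translation substitution for *main* can be sketched with a regular
expression. *main_guard_to_def* is a hypothetical helper. It assumes the
guard starts at column 0 and compares with "==" using either quote style.

```python
import re

# Hypothetical helper: turn a top-level  if __name__ == "__main__":  guard
# into "def main():" so the guarded block becomes the body of main.
MAIN_GUARD = re.compile(r'^if\s+__name__\s*==\s*[\'"]__main__[\'"]\s*:',
                        re.MULTILINE)


def main_guard_to_def(source):
    return MAIN_GUARD.sub('def main():', source)


script = 'def f(x):\n    return 2*x\n\nif __name__ == "__main__":\n    print("x %g"%(f(1.5)))\n'
print(main_guard_to_def(script))
```

The matching post-translation rewrite of the generated signature is a plain
string replacement and is omitted here.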

Known issues
------------
The following constructs may cause problems:

* implicit arrays: possible namespace collision for variable "vec#"
* swap fails: "x,y = y,x" will set x==y
* top-level statements: code outside a function body causes errors
* line number skew: each statement should be tagged with its own #line
  to avoid skew as comments are skipped and loop bodies are wrapped with
  braces, etc.
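
The swap failure happens because tuple assignments are emitted one target at
a time, so the second assignment already sees the updated first target. The
sketch below uses hypothetical helpers, not py2c code. It reproduces that
behaviour in plain Python. It then shows one possible swap-safe translation
that evaluates every right-hand side into a block-scoped temporary first,
assuming all values are doubles.

```python
def run_sequential(env, targets, values):
    # Mimic the current translation of "x, y = y, x": one assignment at a
    # time, so later right-hand sides see already-updated variables.
    for name, source in zip(targets, values):
        env[name] = env[source]
    return env


def tuple_assign_to_c(targets, values):
    # Swap-safe C for "t0, t1, ... = v0, v1, ...": copy every value into a
    # temporary inside its own block, then assign the temporaries.
    if len(targets) != len(values):
        raise ValueError('mismatched tuple assignment')
    temps = ['_tmp%d' % k for k in range(len(values))]
    parts = ['double %s = %s;' % (t, v) for t, v in zip(temps, values)]
    parts += ['%s = %s;' % (name, t) for name, t in zip(targets, temps)]
    return '{ ' + ' '.join(parts) + ' }'


print(run_sequential({'x': 1.0, 'y': 2.0}, ['x', 'y'], ['y', 'x']))
# {'x': 2.0, 'y': 2.0}
print(tuple_assign_to_c(['x', 'y'], ['y', 'x']))
# { double _tmp0 = y; double _tmp1 = x; x = _tmp0; y = _tmp1; }
```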

References
----------

Based on a variant of codegen.py:

    https://github.com/andreif/codegen
    :copyright: Copyright 2008 by Armin Ronacher.
    :license: BSD.
"""

# Update Notes
# ============
# 2017-11-22, OE: Each 'visit_*' method is to build a C statement string. It
#                 should insert 4 blanks per indentation level. The 'body'
#                 method will combine all the strings, by adding the
#                 'current_statement' to the c_proc string list
# 2017-11-22, OE: variables, argument definition implemented.  Note: An
#                 argument is considered an array if it is the target of an
#                 assignment. In that case it is translated to <var>[0]
# 2017-11-27, OE: 'pow' basically working
# 2017-12-07, OE: Multiple assignment: a1,a2,...,an=b1,b2,...bn implemented
# 2017-12-07, OE: Power function, including special cases of
#                 square(x)(pow(x,2)) and cube(x)(pow(x,3)), implemented in
#                 translate_power, called from visit_BinOp
# 2017-12-07, OE: Translation of integer division, '//' in python, implemented
#                 in translate_integer_divide, called from visit_BinOp
# 2017-12-07, OE: C variable definition handled in 'define_c_vars'
#               : Python integer division, '//', translated to C in
#                 'translate_integer_divide'
# 2017-12-15, OE: Precedence maintained by writing opening and closing
#                 parentheses, '(',')', in procedure 'visit_BinOp'.
# 2017-12-18, OE: Added call to 'add_current_line()' at the beginning
#                 of visit_Return
# 2018-01-03, PK: Update interface for use in sasmodels
# 2018-01-03, PK: support "expr if cond else expr" syntax
# 2018-01-03, PK: x//y => (int)((x)/(y)) and x/y => ((double)(x)/(double)(y))
# 2018-01-03, PK: True/False => true/false
# 2018-01-03, PK: f(x) was introducing an extra semicolon
# 2018-01-03, PK: simplistic print function, for debugging
# 2018-01-03, PK: while expr: ... => while (expr) { ... }
# 2018-01-04, OE: Fixed bug in 'visit_If': visiting node.orelse in case else exists.

from __future__ import print_function

import sys
import ast
from ast import NodeVisitor
from inspect import currentframe, getframeinfo

try: # for debugging, astor lets us print out the node as python
    import astor
except ImportError:
    pass

BINOP_SYMBOLS = {}
BINOP_SYMBOLS[ast.Add] = '+'
BINOP_SYMBOLS[ast.Sub] = '-'
BINOP_SYMBOLS[ast.Mult] = '*'
BINOP_SYMBOLS[ast.Div] = '/'
BINOP_SYMBOLS[ast.Mod] = '%'
BINOP_SYMBOLS[ast.Pow] = '**'
BINOP_SYMBOLS[ast.LShift] = '<<'
BINOP_SYMBOLS[ast.RShift] = '>>'
BINOP_SYMBOLS[ast.BitOr] = '|'
BINOP_SYMBOLS[ast.BitXor] = '^'
BINOP_SYMBOLS[ast.BitAnd] = '&'
BINOP_SYMBOLS[ast.FloorDiv] = '//'

BOOLOP_SYMBOLS = {}
BOOLOP_SYMBOLS[ast.And] = '&&'
BOOLOP_SYMBOLS[ast.Or] = '||'

CMPOP_SYMBOLS = {}
CMPOP_SYMBOLS[ast.Eq] = '=='
CMPOP_SYMBOLS[ast.NotEq] = '!='
CMPOP_SYMBOLS[ast.Lt] = '<'
CMPOP_SYMBOLS[ast.LtE] = '<='
CMPOP_SYMBOLS[ast.Gt] = '>'
CMPOP_SYMBOLS[ast.GtE] = '>='
CMPOP_SYMBOLS[ast.Is] = 'is'
CMPOP_SYMBOLS[ast.IsNot] = 'is not'
CMPOP_SYMBOLS[ast.In] = 'in'
CMPOP_SYMBOLS[ast.NotIn] = 'not in'

UNARYOP_SYMBOLS = {}
UNARYOP_SYMBOLS[ast.Invert] = '~'
UNARYOP_SYMBOLS[ast.Not] = 'not'
UNARYOP_SYMBOLS[ast.UAdd] = '+'
UNARYOP_SYMBOLS[ast.USub] = '-'


# TODO: should not allow eval of arbitrary python
def isevaluable(s):
    try:
        eval(s)
        return True
    except Exception:
        return False

def render_expression(tree):
    generator = SourceGenerator()
    generator.visit(tree)
    c_code = "".join(generator.current_statement)
    return c_code

class SourceGenerator(NodeVisitor):
    """This visitor transforms a well-formed python syntax tree into C
    source code.  For more details have a look at the docstring of the
    :func:`translate` function.
    """

    def __init__(self, indent_with="    ", constants=None, fname=None, lineno=0):
        self.indent_with = indent_with
        self.indentation = 0

        # for C
        self.c_proc = []
        self.signature_line = 0
        self.arguments = []
        self.current_function = ""
        self.fname = fname
        self.lineno_offset = lineno
        self.warnings = []
        self.current_statement = ""
        # TODO: use set rather than list for c_vars, ...
        self.c_vars = []
        self.c_int_vars = []
        self.c_pointers = []
        self.c_dcl_pointers = []
        self.c_functions = []
        self.c_vectors = []
        self.c_constants = constants if constants is not None else {}
        self.in_expr = False
        self.in_subref = False
        self.in_subscript = False
        self.tuples = []
        self.required_functions = []
        self.visited_args = False

    def write_c(self, statement):
        # TODO: build up as a list rather than adding to string
        self.current_statement += statement

    def write_python(self, x):
        raise NotImplementedError("shouldn't be trying to write python")

    def add_c_line(self, line):
        indentation = self.indent_with * self.indentation
        self.c_proc.append("".join((indentation, line, "\n")))

    def add_current_line(self):
        if self.current_statement:
            self.add_c_line(self.current_statement)
            self.current_statement = ''

    def add_unique_var(self, new_var):
        if new_var not in self.c_vars:
            self.c_vars.append(str(new_var))

    def write_sincos(self, node):
        angle = str(node.args[0].id)
        self.write_c(node.args[1].id + " = sin(" + angle + ");")
        self.add_current_line()
        self.write_c(node.args[2].id + " = cos(" + angle + ");")
        self.add_current_line()
        for arg in node.args:
            self.add_unique_var(arg.id)

    def track_lineno(self, node):
        #print("newline", node, [s for s in dir(node) if not s.startswith('_')])
        if hasattr(node, 'lineno'):
            line = '#line %d "%s"\n' % (node.lineno+self.lineno_offset-1, self.fname)
            self.c_proc.append(line)

    def body(self, statements):
        if self.current_statement:
            self.add_current_line()
        self.new_line = True
        self.indentation += 1
        for stmt in statements:
            #if hasattr(stmt, 'targets') and hasattr(stmt.targets[0], 'id'):
            #    target_name = stmt.targets[0].id # target name needed for debug only
            self.visit(stmt)
        self.add_current_line() # flush any pending statement before closing the block
        self.indentation -= 1

    def body_or_else(self, node):
        self.body(node.body)
        if node.orelse:
            self.unsupported(node, "for...else/while...else not supported")

            self.track_lineno(node)
            self.write_c('else:')
            self.body(node.orelse)

    def signature(self, node):
        want_comma = []
        def write_comma():
            if want_comma:
                self.write_c(', ')
            else:
                want_comma.append(True)

        # for C
        for arg in node.args:
            # CRUFT: 2.7 uses arg.id, 3.x uses arg.arg
            try:
                arg_name = arg.arg
            except AttributeError:
                arg_name = arg.id
            self.arguments.append(arg_name)

        padding = [None] * (len(node.args) - len(node.defaults))
        for arg, default in zip(node.args, padding + node.defaults):
            if default is not None:
                # CRUFT: 2.7 uses arg.id, 3.x uses arg.arg
                try:
                    arg_name = arg.arg
                except AttributeError:
                    arg_name = arg.id
                w_str = ("C does not support default parameters: %s=%s"
                         % (arg_name, str(default.n)))
                self.warnings.append(w_str)

    def decorators(self, node):
        if node.decorator_list:
            self.unsupported(node.decorator_list[0])
        for decorator in node.decorator_list:
            self.track_lineno(decorator)
            self.write_python('@')
            self.visit(decorator)

    # Statements

    def visit_Assert(self, node):
        self.unsupported(node)

        self.track_lineno(node)
        self.write_c('assert ')
        self.visit(node.test)
        if node.msg is not None:
            self.write_python(', ')
            self.visit(node.msg)

    def define_c_vars(self, target):
        # A variable is considered an array if it appears in the argument list
        # and is assigned to. For example, the variable p in the following
        # snippet is a pointer, while q is not:
        # def somefunc(p, q):
        #  p = q + 1
        #  return
        if hasattr(target, 'id'):
            if target.id not in self.c_vars:
                if target.id in self.arguments:
                    idx = self.arguments.index(target.id)
                    new_target = self.arguments[idx] + "[0]"
                    if new_target not in self.c_pointers:
                        target.id = new_target
                        self.c_pointers.append(self.arguments[idx])
                else:
                    self.c_vars.append(target.id)

    def add_semi_colon(self):
        #semi_pos = self.current_statement.find(';')
        #if semi_pos >= 0:
        #    self.current_statement = self.current_statement.replace(';', '')
        self.write_c(';')

    def visit_Assign(self, node):
        self.add_current_line()
        self.track_lineno(node)
        self.in_expr = True
        for idx, target in enumerate(node.targets): # multi assign, as in 'a = b = c = 7'
            if idx:
                self.write_c(' = ')
            self.define_c_vars(target)
            self.visit(target)
        # Capture assigned tuple names, if any
        targets = self.tuples[:]
        del self.tuples[:]
        self.write_c(' = ')
        self.visited_args = False
        self.visit(node.value)
        self.add_semi_colon()
        self.add_current_line()
        # Assign tuples to tuples, if any
        # TODO: doesn't handle swap:  a,b = b,a
        for target, item in zip(targets, self.tuples):
            self.visit(target)
            self.write_c(' = ')
            self.visit(item)
            self.add_semi_colon()
            self.add_current_line()
        # Drop the leftover value items so they do not leak into the
        # targets captured by the next tuple assignment
        del self.tuples[:]
        #if self.is_sequence and not self.visited_args:
        #    for target in node.targets:
        #        if hasattr(target, 'id'):
        #            if target.id in self.c_vars and target.id not in self.c_dcl_pointers:
        #                if target.id not in self.c_dcl_pointers:
        #                    self.c_dcl_pointers.append(target.id)
        #                    if target.id in self.c_vars:
        #                        self.c_vars.remove(target.id)
        self.current_statement = ''
        self.in_expr = False

    def visit_AugAssign(self, node):
        if node.target.id not in self.c_vars:
            if node.target.id not in self.arguments:
                self.c_vars.append(node.target.id)
        self.in_expr = True
        self.visit(node.target)
        self.write_c(' ' + BINOP_SYMBOLS[type(node.op)] + '= ')
        self.visit(node.value)
        self.add_semi_colon()
        self.in_expr = False
        self.add_current_line()

    def visit_ImportFrom(self, node):
        return  # import ignored
        self.track_lineno(node)
        self.write_python('from %s%s import ' %('.' * node.level, node.module))
        for idx, item in enumerate(node.names):
            if idx:
                self.write_python(', ')
            self.write_python(item)

    def visit_Import(self, node):
        return  # import ignored
        self.track_lineno(node)
        for item in node.names:
            self.write_python('import ')
            self.visit(item)

    def visit_Expr(self, node):
        #self.in_expr = True
        #self.track_lineno(node)
        self.generic_visit(node)
        #self.in_expr = False

    def write_c_pointers(self, start_var):
        if self.c_dcl_pointers:
            var_list = []
            for c_ptr in self.c_dcl_pointers:
                if c_ptr not in self.arguments:
                    var_list.append("*" + c_ptr)
                if c_ptr in self.c_vars:
                    self.c_vars.remove(c_ptr)
            if var_list:
                c_dcl = "    double " + ", ".join(var_list) + ";\n"
                self.c_proc.insert(start_var, c_dcl)
                start_var += 1
        return start_var

    def insert_c_vars(self, start_var):
        have_decls = False
        start_var = self.write_c_pointers(start_var)
        if self.c_int_vars:
            for var in self.c_int_vars:
                if var in self.c_vars:
                    self.c_vars.remove(var)
            decls = ", ".join(self.c_int_vars)
            self.c_proc.insert(start_var, "    int " + decls + ";\n")
            have_decls = True
            start_var += 1

        if self.c_vars:
            decls = ", ".join(self.c_vars)
            self.c_proc.insert(start_var, "    double " + decls + ";\n")
            have_decls = True
            start_var += 1

        if self.c_vectors:
            for vec_number, vec_value in enumerate(self.c_vectors):
                name = "vec" + str(vec_number + 1)
                decl = "    double " + name + "[] = {" + vec_value + "};"
                self.c_proc.insert(start_var, decl + "\n")
                start_var += 1

        del self.c_vars[:]
        del self.c_int_vars[:]
        del self.c_vectors[:]
        del self.c_pointers[:]
        del self.c_dcl_pointers[:]
        if have_decls:
            self.c_proc.insert(start_var, "\n")

    def insert_signature(self):
        arg_decls = []
        for arg in self.arguments:
            decl = "double " + arg
            if arg in self.c_pointers:
                decl += "[]"
            arg_decls.append(decl)
        args_str = ", ".join(arg_decls)
        method_sig = 'double ' + self.current_function + '(' + args_str + ")"
        if self.signature_line >= 0:
            self.c_proc.insert(self.signature_line, method_sig)

    def visit_FunctionDef(self, node):
        if self.current_function:
            self.unsupported(node, "function within a function")
        self.current_function = node.name

        # remember the location of the next warning that will be inserted
        # so that we can stuff the function name ahead of the warning list
        # if any warnings are generated by the function.
        warning_index = len(self.warnings)

        self.decorators(node)
        self.track_lineno(node)
        self.arguments = []
        self.visit(node.args)
        # for C
        self.signature_line = len(self.c_proc)
        self.add_c_line("\n{")
        start_vars = len(self.c_proc) + 1
        self.body(node.body)
        self.add_c_line("}\n")
        self.insert_signature()
        self.insert_c_vars(start_vars)
        del self.c_pointers[:]
        self.current_function = ""
        if warning_index != len(self.warnings):
            self.warnings.insert(warning_index, "Warning in function '" + node.name + "':")

    def visit_ClassDef(self, node):
        self.unsupported(node)

        have_args = []
        def paren_or_comma():
            if have_args:
                self.write_python(', ')
            else:
                have_args.append(True)
                self.write_python('(')

        self.decorators(node)
        self.track_lineno(node)
        self.write_python('class %s' % node.name)
        for base in node.bases:
            paren_or_comma()
            self.visit(base)
        # CRUFT: python 2.6 does not have "keywords" attribute
        if hasattr(node, 'keywords'):
            for keyword in node.keywords:
                paren_or_comma()
                self.write_python(keyword.arg + '=')
                self.visit(keyword.value)
            if node.starargs is not None:
                paren_or_comma()
                self.write_python('*')
                self.visit(node.starargs)
            if node.kwargs is not None:
                paren_or_comma()
                self.write_python('**')
                self.visit(node.kwargs)
        self.write_python(have_args and '):' or ':')
        self.body(node.body)

    def visit_If(self, node):

        self.track_lineno(node)
        self.write_c('if ')
        self.in_expr = True
        self.visit(node.test)
        self.in_expr = False
        self.write_c(' {')
        self.body(node.body)
        self.add_c_line('}')
        while True:
            else_ = node.orelse
            if len(else_) == 0:
                break
            #elif hasattr(else_, 'orelse'):
            elif len(else_) == 1 and isinstance(else_[0], ast.If):
                node = else_[0]
                self.track_lineno(node)
                self.write_c('else if ')
                self.in_expr = True
                self.visit(node.test)
                self.in_expr = False
                self.write_c(' {')
                self.body(node.body)
                self.add_current_line()
                self.add_c_line('}')
                #break
            else:
                self.track_lineno(else_)
                self.write_c('else {')
                self.body(else_)
                self.add_c_line('}')
                break

    def get_for_range(self, node):
        stop = ""
        start = '0'
        step = '1'
        for_args = []
        temp_statement = self.current_statement
        self.current_statement = ''
        for arg in node.iter.args:
            self.visit(arg)
            for_args.append(self.current_statement)
            self.current_statement = ''
        self.current_statement = temp_statement
        if len(for_args) == 1:
            stop = for_args[0]
        elif len(for_args) == 2:
            start = for_args[0]
            stop = for_args[1]
        elif len(for_args) == 3:
            start = for_args[0]
            stop = for_args[1]
            step = for_args[2]
        else:
            raise ValueError("Illegal for loop parameters")
        return start, stop, step

    def add_c_int_var(self, name):
        if name not in self.c_int_vars:
            self.c_int_vars.append(name)

    def visit_For(self, node):
        # node: for iterator is stored in node.target.
        # Iterator name is in node.target.id.
        self.add_current_line()
        fForDone = False
        self.current_statement = ''
        if hasattr(node.iter, 'func'):
            if hasattr(node.iter.func, 'id'):
                if node.iter.func.id == 'range':
                    self.visit(node.target)
                    iterator = self.current_statement
                    self.current_statement = ''
                    self.add_c_int_var(iterator)
                    start, stop, step = self.get_for_range(node)
                    self.write_c("for (" + iterator + "=" + str(start) +
                                 " ; " + iterator + " < " + str(stop) +
                                 " ; " + iterator + " += " + str(step) + ") {")
                    self.body_or_else(node)
                    self.write_c("}")
                    fForDone = True
        if not fForDone:
            # Generate the statement that is causing the error
            self.current_statement = ''
            self.write_c('for ')
            self.visit(node.target)
            self.write_c(' in ')
            self.visit(node.iter)
            self.write_c(':')
            # report the error
            self.unsupported(node, "unsupported " + self.current_statement)

    def visit_While(self, node):
        self.track_lineno(node)
        self.write_c('while ')
        self.visit(node.test)
        self.write_c(' {')
        self.body_or_else(node)
        self.write_c('}')
        self.add_current_line()

    def visit_With(self, node):
        self.unsupported(node)

        self.track_lineno(node)
        self.write_python('with ')
        self.visit(node.context_expr)
        if node.optional_vars is not None:
            self.write_python(' as ')
            self.visit(node.optional_vars)
        self.write_python(':')
        self.body(node.body)

    def visit_Pass(self, node):
        #self.track_lineno(node)
        #self.write_python('pass')
        pass

    def visit_Print(self, node):
        self.unsupported(node)

        # CRUFT: python 2.x only
        self.track_lineno(node)
        self.write_c('print ')
        want_comma = False
        if node.dest is not None:
            self.write_c(' >> ')
            self.visit(node.dest)
            want_comma = True
        for value in node.values:
            if want_comma:
                self.write_c(', ')
            self.visit(value)
            want_comma = True
        if not node.nl:
            self.write_c(',')

    def visit_Delete(self, node):
        self.unsupported(node)

        self.track_lineno(node)
        self.write_python('del ')
        for idx, target in enumerate(node.targets):
            if idx:
                self.write_python(', ')
            self.visit(target)

    def visit_TryExcept(self, node):
        self.unsupported(node)

        self.track_lineno(node)
        self.write_python('try:')
        self.body(node.body)
        for handler in node.handlers:
            self.visit(handler)

    def visit_TryFinally(self, node):
        self.unsupported(node)

        self.track_lineno(node)
        self.write_python('try:')
        self.body(node.body)
        self.track_lineno(node)
        self.write_python('finally:')
        self.body(node.finalbody)

    def visit_Global(self, node):
        self.unsupported(node)

        self.track_lineno(node)
        self.write_python('global ' + ', '.join(node.names))

    def visit_Nonlocal(self, node):
        self.track_lineno(node)
        self.write_python('nonlocal ' + ', '.join(node.names))

    def visit_Return(self, node):
        self.add_current_line()
        self.track_lineno(node)
        self.in_expr = True
        if node.value is None:
            self.write_c('return')
        else:
            self.write_c('return ')
            self.visit(node.value)
        self.add_semi_colon()
        self.in_expr = False
        self.add_c_line(self.current_statement)
        self.current_statement = ''

    def visit_Break(self, node):
        self.track_lineno(node)
        self.write_c('break;')

    def visit_Continue(self, node):
        self.track_lineno(node)
        self.write_c('continue;')

    def visit_Raise(self, node):
        self.unsupported(node)

        # CRUFT: Python 2.6 / 3.0 compatibility
        self.track_lineno(node)
        self.write_python('raise')
        if hasattr(node, 'exc') and node.exc is not None:
            self.write_python(' ')
            self.visit(node.exc)
            if node.cause is not None:
                self.write_python(' from ')
                self.visit(node.cause)
        elif hasattr(node, 'type') and node.type is not None:
            self.visit(node.type)
            if node.inst is not None:
                self.write_python(', ')
                self.visit(node.inst)
            if node.tback is not None:
                self.write_python(', ')
                self.visit(node.tback)

    # Expressions

    def visit_Attribute(self, node):
        self.unsupported(node, "attribute reference a.b not supported")

        self.visit(node.value)
        self.write_python('.' + node.attr)
789
790    def visit_Call(self, node):
791        want_comma = []
792        def write_comma():
793            if want_comma:
794                self.write_c(', ')
795            else:
796                want_comma.append(True)
797        if hasattr(node.func, 'id'):
798            if node.func.id not in self.c_functions:
799                self.c_functions.append(node.func.id)
800            if node.func.id == 'abs':
801                self.write_c("fabs ")
802            elif node.func.id == 'int':
803                self.write_c('(int) ')
804            elif node.func.id == "SINCOS":
805                self.write_sincos(node)
806                return
807            else:
808                self.visit(node.func)
809        else:
810            self.visit(node.func)
811        self.write_c('(')
812        for arg in node.args:
813            write_comma()
814            self.visited_args = True
815            self.visit(arg)
816        for keyword in node.keywords:
817            write_comma()
818            self.write_c(keyword.arg + '=')
819            self.visit(keyword.value)
820        if hasattr(node, 'starargs'):
821            if node.starargs is not None:
822                write_comma()
823                self.write_c('*')
824                self.visit(node.starargs)
825        if hasattr(node, 'kwargs'):
826            if node.kwargs is not None:
827                write_comma()
828                self.write_c('**')
829                self.visit(node.kwargs)
830        self.write_c(')')
831        if not self.in_expr:
832            self.add_semi_colon()
833
834    TRANSLATE_CONSTANTS = {
835        # python 2 uses normal name references through vist_Name
836        'True': 'true',
837        'False': 'false',
838        'None': 'NULL',  # "None" will probably fail for other reasons
839        # python 3 uses NameConstant
840        True: 'true',
841        False: 'false',
842        None: 'NULL',  # "None" will probably fail for other reasons
843        }
844
845    def visit_Name(self, node):
846        translation = self.TRANSLATE_CONSTANTS.get(node.id, None)
847        if translation:
848            self.write_c(translation)
849            return
850        self.write_c(node.id)
851        if node.id in self.c_pointers and not self.in_subref:
852            self.write_c("[0]")
853        name = ""
854        sub = node.id.find("[")
855        if sub > 0:
856            name = node.id[0:sub].strip()
857        else:
858            name = node.id
859        # add variable to c_vars if it isn't there yet, not an argument and not a number
860        if (name not in self.c_functions and name not in self.c_vars and
861                name not in self.c_int_vars and name not in self.arguments and
862                name not in self.c_constants and not name.isdigit()):
863            if self.in_subscript:
864                self.add_c_int_var(node.id)
865            else:
866                self.c_vars.append(node.id)
867
868    def visit_NameConstant(self, node):
869        translation = self.TRANSLATE_CONSTANTS.get(node.value, None)
870        if translation is not None:
871            self.write_c(translation)
872        else:
873            self.unsupported(node, "don't know how to translate %r"%node.value)
874
875    def visit_Str(self, node):
876        s = node.s
877        s = s.replace('\\', '\\\\').replace('"', '\\"').replace('\n', '\\n')
878        self.write_c('"')
879        self.write_c(s)
880        self.write_c('"')
881
882    def visit_Bytes(self, node):
883        s = node.s.decode('latin-1')  # ast.Bytes holds bytes; escape as str
884        s = s.replace('\\', '\\\\').replace('"', '\\"').replace('\n', '\\n')
885        self.write_c('"')
886        self.write_c(s)
887        self.write_c('"')
888
889    def visit_Num(self, node):
890        self.write_c(repr(node.n))
891
892    def visit_Tuple(self, node):
893        for idx, item in enumerate(node.elts):
894            if idx:
895                self.tuples.append(item)
896            else:
897                self.visit(item)
898
899    def visit_List(self, node):
900        #self.unsupported(node)
901        #print("visiting", node)
902        #print(astor.to_source(node))
903        #print(node.elts)
904        exprs = [render_expression(item) for item in node.elts]
905        if exprs:
906            self.c_vectors.append(', '.join(exprs))
907            vec_name = "vec"  + str(len(self.c_vectors))
908            self.write_c(vec_name)
909
910    def visit_Set(self, node):
911        self.unsupported(node)
912
913    def visit_Dict(self, node):
914        self.unsupported(node)
915
916        self.write_python('{')
917        for idx, (key, value) in enumerate(zip(node.keys, node.values)):
918            if idx:
919                self.write_python(', ')
920            self.visit(key)
921            self.write_python(': ')
922            self.visit(value)
923        self.write_python('}')
924
925    def get_special_power(self, string):
926        function_name = ''
927        is_negative_exp = False
928        if isevaluable(str(self.current_statement)):
929            exponent = eval(string)
930            is_negative_exp = exponent < 0
931            abs_exponent = abs(exponent)
932            if abs_exponent == 2:
933                function_name = "square"
934            elif abs_exponent == 3:
935                function_name = "cube"
936            elif abs_exponent == 0.5:
937                function_name = "sqrt"
938            elif abs_exponent == 1.0/3.0:
939                function_name = "cbrt"
940        if function_name == '':
941            function_name, is_negative_exp = "pow", False  # pow gets the signed exponent
942        return function_name, is_negative_exp
943
944    def translate_power(self, node):
945        # get exponent by visiting the right hand argument.
946        function_name = "pow"
947        temp_statement = self.current_statement
948        # 'visit' functions write their results to the 'current_statement'
949        # member, so save it in 'temp_statement', visit the exponent into an
950        # empty statement, and restore the saved statement afterwards.
951        self.current_statement = ''
952        self.visit(node.right)
953        exponent = self.current_statement.replace(' ', '')
954        function_name, is_negative_exp = self.get_special_power(self.current_statement)
955        self.current_statement = temp_statement
956        if is_negative_exp:
957            self.write_c("1.0 /(")
958        self.write_c(function_name + "(")
959        self.visit(node.left)
960        if function_name == "pow":
961            self.write_c(", ")
962            self.visit(node.right)
963        self.write_c(")")
964        if is_negative_exp:
965            self.write_c(")")
966        #self.write_c(" ")
967
968    def translate_integer_divide(self, node):
969        self.write_c("(int)((")
970        self.visit(node.left)
971        self.write_c(")/(")
972        self.visit(node.right)
973        self.write_c("))")
974
975    def translate_float_divide(self, node):
976        self.write_c("((double)(")
977        self.visit(node.left)
978        self.write_c(")/(double)(")
979        self.visit(node.right)
980        self.write_c("))")
981
982    def visit_BinOp(self, node):
983        self.write_c("(")
984        if isinstance(node.op, ast.Pow):
985            self.translate_power(node)
986        elif isinstance(node.op, ast.FloorDiv):
987            self.translate_integer_divide(node)
988        elif isinstance(node.op, ast.Div):
989            self.translate_float_divide(node)
990        else:
991            self.visit(node.left)
992            self.write_c(' %s ' % BINOP_SYMBOLS[type(node.op)])
993            self.visit(node.right)
994        self.write_c(")")
995
996    # for C
997    def visit_BoolOp(self, node):
998        self.write_c('(')
999        for idx, value in enumerate(node.values):
1000            if idx:
1001                self.write_c(' %s ' % BOOLOP_SYMBOLS[type(node.op)])
1002            self.visit(value)
1003        self.write_c(')')
1004
1005    def visit_Compare(self, node):
1006        self.write_c('(')
1007        self.visit(node.left)
1008        for op, right in zip(node.ops, node.comparators):
1009            self.write_c(' %s ' % CMPOP_SYMBOLS[type(op)])
1010            self.visit(right)
1011        self.write_c(')')
1012
1013    def visit_UnaryOp(self, node):
1014        self.write_c('(')
1015        op = UNARYOP_SYMBOLS[type(node.op)]
1016        self.write_c(op)
1017        if op == 'not':
1018            self.write_c(' ')
1019        self.visit(node.operand)
1020        self.write_c(')')
1021
1022    def visit_Subscript(self, node):
1023        if node.value.id not in self.c_constants:
1024            if node.value.id not in self.c_pointers:
1025                self.c_pointers.append(node.value.id)
1026        self.in_subref = True
1027        self.visit(node.value)
1028        self.in_subref = False
1029        self.write_c('[')
1030        self.in_subscript = True
1031        self.visit(node.slice)
1032        self.in_subscript = False
1033        self.write_c(']')
1034
1035    def visit_Slice(self, node):
1036        if node.lower is not None:
1037            self.visit(node.lower)
1038        self.write_python(':')
1039        if node.upper is not None:
1040            self.visit(node.upper)
1041        if node.step is not None:
1042            self.write_python(':')
1043            if not(isinstance(node.step, Name) and node.step.id == 'None'):
1044                self.visit(node.step)
1045
1046    def visit_ExtSlice(self, node):
1047        for idx, item in enumerate(node.dims):
1048            if idx:
1049                self.write_python(', ')
1050            self.visit(item)
1051
1052    def visit_Yield(self, node):
1053        self.unsupported(node)
1054
1055        self.write_python('yield ')
1056        self.visit(node.value)
1057
1058    def visit_Lambda(self, node):
1059        self.unsupported(node)
1060
1061        self.write_python('lambda ')
1062        self.visit(node.args)
1063        self.write_python(': ')
1064        self.visit(node.body)
1065
1066    def visit_Ellipsis(self, node):
1067        self.unsupported(node)
1068
1069        self.write_python('Ellipsis')
1070
1071    def generator_visit(left, right):
1072        def visit(self, node):
1073            self.write_python(left)
1074            self.write_c(left)
1075            self.visit(node.elt)
1076            for comprehension in node.generators:
1077                self.visit(comprehension)
1078            self.write_c(right)
1079            #self.write_python(right)
1080        return visit
1081
1082    visit_ListComp = generator_visit('[', ']')
1083    visit_GeneratorExp = generator_visit('(', ')')
1084    visit_SetComp = generator_visit('{', '}')
1085    del generator_visit
1086
1087    def visit_DictComp(self, node):
1088        self.unsupported(node)
1089
1090        self.write_python('{')
1091        self.visit(node.key)
1092        self.write_python(': ')
1093        self.visit(node.value)
1094        for comprehension in node.generators:
1095            self.visit(comprehension)
1096        self.write_python('}')
1097
1098    def visit_IfExp(self, node):
1099        self.write_c('((')
1100        self.visit(node.test)
1101        self.write_c(')?(')
1102        self.visit(node.body)
1103        self.write_c('):(')
1104        self.visit(node.orelse)
1105        self.write_c('))')
1106
1107    def visit_Starred(self, node):
1108        self.write_c('*')
1109        self.visit(node.value)
1110
1111    def visit_Repr(self, node):
1112        # CRUFT: python 2.6 only
1113        self.write_c('`')
1114        self.visit(node.value)
1115        self.write_python('`')
1116
1117    # Helper Nodes
1118
1119    def visit_alias(self, node):
1120        self.unsupported(node)
1121
1122        self.write_python(node.name)
1123        if node.asname is not None:
1124            self.write_python(' as ' + node.asname)
1125
1126    def visit_comprehension(self, node):
1127        self.unsupported(node)
1128
1129        self.write_c(' for ')
1130        self.visit(node.target)
1131        self.write_c(' in ')
1132        #self.write_python(' in ')
1133        self.visit(node.iter)
1134        if node.ifs:
1135            for if_ in node.ifs:
1136                self.write_python(' if ')
1137                self.visit(if_)
1138
1139    def visit_arguments(self, node):
1140        self.signature(node)
1141
1142    def unsupported(self, node, message=None):
1143        if hasattr(node, "value") and hasattr(node.value, "lineno"):
1144            lineno = node.value.lineno
1145        elif hasattr(node, "iter"):
1146            lineno = node.iter.lineno
1147        else:
1148            # fall back to the node's own line number, if it has one
1149            lineno = getattr(node, "lineno", 0)
1150
1151        lineno += self.lineno_offset
1152        if self.fname:
1153            location = "%s(%d)" % (self.fname, lineno)
1154        else:
1155            location = "%d" % lineno
1156        if self.current_function:
1157            location += ", function %s" % self.current_function
1158        if message is None:
1159            message = node.__class__.__name__ + " syntax not supported"
1160        raise SyntaxError("[%s] %s" % (location, message))
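The exponent rule in `get_special_power` and `translate_power` can be summarized outside the AST machinery. The sketch below works on a numeric exponent rather than on visited source text, and `special_power` and `render_power` are hypothetical names. Exponents of 2, 3, 1/2 and 1/3 map to `square`, `cube`, `sqrt` and `cbrt`. A negative special exponent becomes a reciprocal. Any other exponent, including a negative one, is passed to `pow` with its sign intact.

```python
from typing import Tuple

# Helper functions for the exponents that get special treatment.
SPECIAL = {2: "square", 3: "cube", 0.5: "sqrt", 1.0/3.0: "cbrt"}

def special_power(exponent: float) -> Tuple[str, bool]:
    """Return (function_name, invert) for x**exponent (hypothetical helper)."""
    name = SPECIAL.get(abs(exponent))
    if name is None:
        # pow() receives the signed exponent, so no reciprocal is needed.
        return "pow", False
    return name, exponent < 0

def render_power(base: str, exponent: float) -> str:
    """Render a C expression for base**exponent (hypothetical helper)."""
    name, invert = special_power(exponent)
    if name == "pow":
        call = "%s(%s, %r)" % (name, base, exponent)
    else:
        call = "%s(%s)" % (name, base)
    return "1.0 /(%s)" % call if invert else call

print(render_power("x", -2))   # 1.0 /(square(x))
print(render_power("x", -2.5)) # pow(x, -2.5)
```

The real translator applies this only when the exponent text can be evaluated. It also emits the exponent expression as visited, for example `(-2.5)`, rather than its `repr`.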
1161
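`translate_integer_divide` emits `(int)((a)/(b))`. In C that quotient is truncated toward zero, whereas Python's `//` rounds toward negative infinity. The two therefore disagree when the division is inexact and exactly one operand is negative. The sketch below models both behaviours in Python. The function names are hypothetical, and `(int)floor(...)` is only one possible C spelling if floor semantics matter.

```python
import math

def python_floordiv(a: int, b: int) -> int:
    # Python's // rounds toward negative infinity.
    return a // b

def c_truncating_div(a: int, b: int) -> int:
    # C integer division, or an (int) cast of a double quotient, truncates
    # toward zero; this models what "(int)((a)/(b))" computes in C.
    return int(a / b)

def c_floor_div(a: float, b: float) -> int:
    # Models "(int)floor((double)(a)/(double)(b))", which keeps Python semantics.
    return int(math.floor(a / b))

print(python_floordiv(-7, 2), c_truncating_div(-7, 2), c_floor_div(-7, 2))
```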
1162def print_function(f=None):
1163    """
1164    Print out the code for the function
1165    """
1166    # Include some comments to see if they get printed
1167    import ast
1168    import inspect
1169    if f is not None:
1170        tree = ast.parse(inspect.getsource(f))
1171        tree_source = to_source(tree)
1172        print(tree_source)
1173
1174def define_constant(name, value, block_size=1):
1175    # type: (str, any, int) -> str
1176    """
1177    Convert a python constant into a C constant of the same name.
1178
1179    Returns the C declaration of the constant as a string, possibly containing
1180    line feeds.  The string will not be indented.
1181
1182    Supports int, double and sequences of double.
1183    """
1184    const = "constant "  # OpenCL needs globals to be constant
1185    if isinstance(value, int):
1186        parts = [const + "int ", name, " = ", "%d"%value, ";\n"]
1187    elif isinstance(value, float):
1188        parts = [const + "double ", name, " = ", "%.15g"%value, ";\n"]
1189    else:
1190        try:
1191            len(value)
1192        except TypeError:
1193            raise TypeError("constant %s must be int, float or [float, ...]"%name)
1194        # Pad constant arrays with zeros to a multiple of block_size. This may
1195        # not be necessary, but some OpenCL targets broke when the number of
1196        # parameters in the parameter table was not a multiple of 4, so pass
1197        # block_size=4 to pad all constant arrays to be safe.
1198        if len(value)%block_size != 0:
1199            value = list(value) + [0.]*(block_size - len(value)%block_size)
1200        elements = ["%.15g"%v for v in value]
1201        parts = [const + "double ", name, "[]", " = ",
1202                 "{\n   ", ", ".join(elements), "\n};\n"]
1203
1204    return "".join(parts)
1205
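The array branch of `define_constant` pads the sequence with zeros to a multiple of `block_size` and formats each value with `%.15g`. Below is a minimal standalone sketch of just that branch; the name `declare_array` is hypothetical.

```python
from typing import Sequence

def declare_array(name: str, values: Sequence[float], block_size: int = 1) -> str:
    """Sketch of the array branch of define_constant (hypothetical name)."""
    values = list(values)
    remainder = len(values) % block_size
    if remainder:
        # Pad with zeros up to the next multiple of block_size.
        values += [0.0] * (block_size - remainder)
    elements = ", ".join("%.15g" % v for v in values)
    return "constant double %s[] = {\n   %s\n};\n" % (name, elements)

print(declare_array("coef", [1.0, 2.5, 3.0], block_size=4))
```

With the default `block_size=1` no padding is added. Passing `block_size=4` produces the OpenCL-friendly padding described in the comment above.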
1206
1207# Modified from the following:
1208#
1209#    http://code.activestate.com/recipes/578272-topological-sort/
1210#    Copyright (C) 2012 Sam Denton
1211#    License: MIT
1212def ordered_dag(dag):
1213    # type: (Dict[T, Set[T]]) -> Iterator[T]
1214    """
1215    Given a dag defined by a dictionary of {k1: {k2, ...}} (each key maps to
1216    the set of keys it depends upon), yield every key after its dependencies.
1217
1218    This is an iterator not a sequence.  To reverse it use::
1219
1220        reversed(tuple(ordered_dag(dag)))
1221
1222    Raise an error if there are any cycles.
1223
1224    Keys are arbitrary hashable values.
1225    """
1226    # Local import to make the function stand-alone, and easier to borrow
1227    from functools import reduce
1228
1229    dag = dag.copy()
1230
1231    # make leaves depend on the empty set
1232    leaves = reduce(set.union, dag.values(), set()) - set(dag.keys())
1233    dag.update({node: set() for node in leaves})
1234    while True:
1235        leaves = set(node for node, links in dag.items() if not links)
1236        if not leaves:
1237            break
1238        for node in leaves:
1239            yield node
1240        dag = {node: (links-leaves)
1241               for node, links in dag.items() if node not in leaves}
1242    if dag:
1243        raise ValueError("Cyclic dependencies exist among these items:\n%s"
1244                         % ", ".join(str(node) for node in dag.keys()))
1245
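`translate` emits functions in the order given and writes no prototypes, so each function must come after the functions it calls. `ordered_dag` provides that order. On Python 3.9+ the standard library's `graphlib.TopologicalSorter` follows the same {node: dependencies} convention and raises `CycleError`, a `ValueError` subclass, on cycles. The sketch below uses it as a swapped-in alternative; the call graph is made up for illustration.

```python
from graphlib import CycleError, TopologicalSorter

# Call graph in the {caller: {callees}} form that ordered_dag expects;
# these function names are illustrative only.
calls = {
    "Iq": {"form_volume", "shell_volume"},
    "form_volume": {"cube"},
    "shell_volume": {"cube"},
}

def definition_order(graph):
    """Return names ordered so every function follows the ones it calls."""
    # TopologicalSorter takes {node: predecessors}, so callees come out first.
    return list(TopologicalSorter(graph).static_order())

order = definition_order(calls)
print(order)
```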
1246import re
1247PRINT_ARGS = re.compile(r'print[(]"(?P<template>[^"]*)" *% *[(](?P<args>[^\n]*)[)] *[)] *\n')
1248SUBST_ARGS = r'printf("\g<template>\\n", \g<args>)\n'
1249PRINT_STR = re.compile(r'print[(]"(?P<template>[^"]*)" *[)] *\n')
1250SUBST_STR = r'printf("\g<template>\\n")\n'
1251def translate(functions, constants=None):
1252    # type: (Sequence[(str, str, int)], Dict[str, any]) -> Tuple[List[str], List[str]]
1253    """
1254    Convert a list of functions to a list of C code strings.
1255
1256    Returns list of corresponding code snippets (with trailing lines in
1257    each block) and a list of warnings generated by the translator.
1258
1259    A function is given by the tuple (source, filename, line number).
1260
1261    Global constants are given in a dictionary of {name: value}.  The
1262    constants are used for name space resolution and type inferencing.
1263    Constants are not translated by this code. Instead, call
1264    :func:`define_constant` with name and value, and maybe block_size
1265    if arrays need to be padded to the next block boundary.
1266
1267    Function prototypes are not generated. Use :func:`ordered_dag` to
1268    sort the functions so that each is defined after the functions it calls
1269    before calling translate. [Maybe a future revision will return the
1270    function prototypes so that a suitable "*.h" file can be generated.]
1271    """
1272    snippets = []
1273    warnings = []
1274    for source, fname, lineno in functions:
1275        line_directive = '#line %d "%s"\n'%(lineno, fname.replace('\\', '\\\\'))
1276        snippets.append(line_directive)
1277        # Replace simple print function calls with printf statements
1278        source = PRINT_ARGS.sub(SUBST_ARGS, source)
1279        source = PRINT_STR.sub(SUBST_STR, source)
1280        tree = ast.parse(source)
1281        generator = SourceGenerator(constants=constants, fname=fname, lineno=lineno)
1282        generator.visit(tree)
1283        c_code = "".join(generator.c_proc)
1284        snippets.append(c_code)
1285        warnings.extend(generator.warnings)
1286    return snippets, warnings
1287
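Before parsing, `translate` uses regular expressions to rewrite simple `print` calls into `printf` calls. The sketch below applies the same two patterns to a small made-up function. Its plain-string substitution writes an escaped `\n` inside the C string and restores the line feed consumed by the pattern, so the rewritten code still parses as Python.

```python
import ast
import re

# Same patterns as PRINT_ARGS/PRINT_STR; both substitutions emit an escaped
# "\n" inside the string and end with a real line feed.
PRINT_ARGS = re.compile(r'print[(]"(?P<template>[^"]*)" *% *[(](?P<args>[^\n]*)[)] *[)] *\n')
SUBST_ARGS = r'printf("\g<template>\\n", \g<args>)\n'
PRINT_STR = re.compile(r'print[(]"(?P<template>[^"]*)" *[)] *\n')
SUBST_STR = r'printf("\g<template>\\n")\n'

def rewrite_prints(source: str) -> str:
    """Turn simple print(...) calls into printf(...) calls before parsing."""
    source = PRINT_ARGS.sub(SUBST_ARGS, source)
    return PRINT_STR.sub(SUBST_STR, source)

src = 'def f(x, y):\n    print("start")\n    print("x=%g y=%g" % (x, y))\n    return x\n'
out = rewrite_prints(src)
tree = ast.parse(out)  # the rewritten code is still valid Python syntax
print(out)
```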
1288
1289C_HEADER_LINENO = getframeinfo(currentframe()).lineno + 3  # line after "#line"
1290C_HEADER = """
1291#line %d "%s"
1292#include <stdio.h>
1293#include <stdbool.h>
1294#include <math.h>
1295#define constant const
1296double square(double x) { return x*x; }
1297double cube(double x) { return x*x*x; }
1298double polyval(constant double *coef, double x, int N)
1299{
1300    int i = 0;
1301    double ans = coef[0];
1302
1303    while (i < N) {
1304        ans = ans * x + coef[++i];
1305    }
1306
1307    return ans;
1308}
1309"""
1310
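`C_HEADER` supplies `square`, `cube` and a `polyval` helper that evaluates a polynomial by Horner's rule, starting from `coef[0]` as the highest-power coefficient. Below is a Python sketch of that rule. The header does not document its arguments, so the sketch assumes `N` is the polynomial degree and `coef` therefore holds `N + 1` values.

```python
from typing import Sequence

def polyval(coef: Sequence[float], x: float, degree: int) -> float:
    """Horner's rule with coef[0] as the highest-power coefficient.

    Assumes coef holds degree + 1 values, mirroring the C helper's
    (coef, x, N) signature with N taken as the polynomial degree.
    """
    ans = coef[0]
    for i in range(1, degree + 1):
        ans = ans * x + coef[i]
    return ans

print(polyval([1.0, 2.0, 3.0], 2.0, 2))  # x**2 + 2x + 3 at x = 2
```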
1311USAGE = """\
1312Usage: python py2c.py <infile> [<outfile>]
1313
1314If outfile is omitted, it is <infile> with the extension changed to '.c'.
1315"""
1316
1317def main():
1318    import os
1319    #print("Parsing...using Python" + sys.version)
1320    if len(sys.argv) == 1:
1321        print(USAGE)
1322        return
1323
1324    fname_in = sys.argv[1]
1325    if len(sys.argv) == 2:
1326        fname_base = os.path.splitext(fname_in)[0]
1327        fname_out = str(fname_base) + '.c'
1328    else:
1329        fname_out = sys.argv[2]
1330
1331    with open(fname_in, "r") as python_file:
1332        code = python_file.read()
1333    name = "gauss"
1334    code = (code
1335            .replace(name+'.n', 'GAUSS_N')
1336            .replace(name+'.z', 'GAUSS_Z')
1337            .replace(name+'.w', 'GAUSS_W')
1338            .replace('if __name__ == "__main__"', "def main()")
1339           )
1340
1341    translation, warnings = translate([(code, fname_in, 1)])
1342    c_code = "".join(translation)
1343    c_code = c_code.replace("double main()", "int main(int argc, char *argv[])")
1344
1345    with open(fname_out, "w") as file_out:
1346        file_out.write(C_HEADER%(C_HEADER_LINENO, __file__))
1347        file_out.write(c_code)
1348
1349    if warnings:
1350        print("\n".join(warnings))
1351    #print("...Done")
1352
1353if __name__ == "__main__":
1354    main()