source: sasmodels/sasmodels/py2c.py @ c01ed3e

Last change on this file since c01ed3e was c01ed3e, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

code cleanup for py2c converter

  • Property mode set to 100644
File size: 37.6 KB
Line 
1"""
2    py2c
3    ~~~~
4
5    Convert simple numeric python code into C code.
6
7    The translate() function works on
8
9    Variables definition in C
10    -------------------------
11    Defining variables within the translate function is a bit of a guess work,
12    using following rules:
13    *   By default, a variable is a 'double'.
14    *   Variable in a for loop is an int.
15    *   Variable that is references with brackets is an array of doubles. The
16        variable within the brackets is integer. For example, in the
17        reference 'var1[var2]', var1 is a double array, and var2 is an integer.
18    *   Assignment to an argument makes that argument an array, and the index
19        in that assignment is 0.
20        For example, the following python code::
21            def func(arg1, arg2):
22                arg2 = 17.
23        is translated to the following C code::
24            double func(double arg1)
25            {
26                arg2[0] = 17.0;
27            }
28        For example, the following python code is translated to the
29        following C code::
30
31            def func(arg1, arg2):          double func(double arg1) {
32                arg2 = 17.                      arg2[0] = 17.0;
33                                            }
34    *   All functions are defined as double, even if there is no
35        return statement.
36
37Based on codegen.py:
38
39    :copyright: Copyright 2008 by Armin Ronacher.
40    :license: BSD.
41"""
42"""
43Update Notes
44============
4511/22/2017, O.E.   Each 'visit_*' method is to build a C statement string. It
46                    shold insert 4 blanks per indentation level.
47                    The 'body' method will combine all the strings, by adding
48                    the 'current_statement' to the c_proc string list
49   11/2017, OE: variables, argument definition implemented.
50   Note: An argument is considered an array if it is the target of an
51        assignment. In that case it is translated to <var>[0]
5211/27/2017, OE: 'pow' basicly working
53  /12/2017, OE: Multiple assignment: a1,a2,...,an=b1,b2,...bn implemented
54  /12/2017, OE: Power function, including special cases of
55                square(x)(pow(x,2)) and cube(x)(pow(x,3)), implemented in
56                translate_power, called from visit_BinOp
5712/07/2017, OE: Translation of integer division, '\\' in python, implemented
58                in translate_integer_divide, called from visit_BinOp
5912/07/2017, OE: C variable definition handled in 'define_c_vars'
60              : Python integer division, '//', translated to C in
61                'translate_integer_divide'
6212/15/2017, OE: Precedence maintained by writing opening and closing
63                parenthesesm '(',')', in procedure 'visit_BinOp'.
6412/18/2017, OE: Added call to 'add_current_line()' at the beginning
65                of visit_Return
66
67"""
68import ast
69import sys
70from ast import NodeVisitor
71
72BINOP_SYMBOLS = {}
73BINOP_SYMBOLS[ast.Add] = '+'
74BINOP_SYMBOLS[ast.Sub] = '-'
75BINOP_SYMBOLS[ast.Mult] = '*'
76BINOP_SYMBOLS[ast.Div] = '/'
77BINOP_SYMBOLS[ast.Mod] = '%'
78BINOP_SYMBOLS[ast.Pow] = '**'
79BINOP_SYMBOLS[ast.LShift] = '<<'
80BINOP_SYMBOLS[ast.RShift] = '>>'
81BINOP_SYMBOLS[ast.BitOr] = '|'
82BINOP_SYMBOLS[ast.BitXor] = '^'
83BINOP_SYMBOLS[ast.BitAnd] = '&'
84BINOP_SYMBOLS[ast.FloorDiv] = '//'
85
86BOOLOP_SYMBOLS = {}
87BOOLOP_SYMBOLS[ast.And] = '&&'
88BOOLOP_SYMBOLS[ast.Or] = '||'
89
90CMPOP_SYMBOLS = {}
91CMPOP_SYMBOLS[ast.Eq] = '=='
92CMPOP_SYMBOLS[ast.NotEq] = '!='
93CMPOP_SYMBOLS[ast.Lt] = '<'
94CMPOP_SYMBOLS[ast.LtE] = '<='
95CMPOP_SYMBOLS[ast.Gt] = '>'
96CMPOP_SYMBOLS[ast.GtE] = '>='
97CMPOP_SYMBOLS[ast.Is] = 'is'
98CMPOP_SYMBOLS[ast.IsNot] = 'is not'
99CMPOP_SYMBOLS[ast.In] = 'in'
100CMPOP_SYMBOLS[ast.NotIn] = 'not in'
101
102UNARYOP_SYMBOLS = {}
103UNARYOP_SYMBOLS[ast.Invert] = '~'
104UNARYOP_SYMBOLS[ast.Not] = 'not'
105UNARYOP_SYMBOLS[ast.UAdd] = '+'
106UNARYOP_SYMBOLS[ast.USub] = '-'
107
108
109def to_source(tree, constants=None, fname=None, lineno=0):
110    """
111    This function can convert a syntax tree into C sourcecode.
112    """
113    generator = SourceGenerator(constants=constants, fname=fname, lineno=lineno)
114    generator.visit(tree)
115    c_code = "\n".join(generator.c_proc)
116    return c_code
117
118def isevaluable(s):
119    try:
120        eval(s)
121        return True
122    except Exception:
123        return False
124
125class SourceGenerator(NodeVisitor):
126    """This visitor is able to transform a well formed syntax tree into python
127    sourcecode.  For more details have a look at the docstring of the
128    `node_to_source` function.
129    """
130
131    def __init__(self, indent_with="    ", add_line_information=False,
132                 constants=None, fname=None, lineno=0):
133        self.result = []
134        self.indent_with = indent_with
135        self.add_line_information = add_line_information
136        self.indentation = 0
137        self.new_lines = 0
138
139        # for C
140        self.c_proc = []
141        self.signature_line = 0
142        self.arguments = []
143        self.current_function = ""
144        self.fname = fname
145        self.lineno_offset = lineno
146        self.warnings = []
147        self.statements = []
148        self.current_statement = ""
149        # TODO: use set rather than list for c_vars, ...
150        self.c_vars = []
151        self.c_int_vars = []
152        self.c_pointers = []
153        self.c_dcl_pointers = []
154        self.c_functions = []
155        self.c_vectors = []
156        self.c_constants = constants if constants is not None else {}
157        self.in_subref = False
158        self.in_subscript = False
159        self.tuples = []
160        self.required_functions = []
161        self.is_sequence = False
162        self.visited_args = False
163
164    def write_python(self, x):
165        if self.new_lines:
166            if self.result:
167                self.result.append('\n' * self.new_lines)
168            self.result.append(self.indent_with * self.indentation)
169            self.new_lines = 0
170        self.result.append(x)
171
172    def write_c(self, statement):
173        # TODO: build up as a list rather than adding to string
174        self.current_statement += statement
175
176    def add_c_line(self, line):
177        indentation = self.indent_with * self.indentation
178        self.c_proc.append("".join((indentation, line, "\n")))
179
180    def add_current_line(self):
181        if self.current_statement:
182            self.add_c_line(self.current_statement)
183            self.current_statement = ''
184
185    def add_unique_var(self, new_var):
186        if new_var not in self.c_vars:
187            self.c_vars.append(str(new_var))
188
189    def write_sincos(self, node):
190        angle = str(node.args[0].id)
191        self.write_c(node.args[1].id + " = sin(" + angle + ");")
192        self.add_current_line()
193        self.write_c(node.args[2].id + " = cos(" + angle + ");")
194        self.add_current_line()
195        for arg in node.args:
196            self.add_unique_var(arg.id)
197
198    def newline(self, node=None, extra=0):
199        self.new_lines = max(self.new_lines, 1 + extra)
200        if node is not None and self.add_line_information:
201            self.write_c('# line: %s' % node.lineno)
202            self.new_lines = 1
203        if self.current_statement:
204            self.statements.append(self.current_statement)
205            self.current_statement = ''
206
207    def body(self, statements):
208        if self.current_statement:
209            self.add_current_line()
210        self.new_line = True
211        self.indentation += 1
212        for stmt in statements:
213            #if hasattr(stmt, 'targets') and hasattr(stmt.targets[0], 'id'):
214            #    target_name = stmt.targets[0].id # target name needed for debug only
215            self.visit(stmt)
216        self.add_current_line() # just for breaking point. to be deleted.
217        self.indentation -= 1
218
219    def body_or_else(self, node):
220        self.body(node.body)
221        if node.orelse:
222            self.newline()
223            self.write_c('else:')
224            self.body(node.orelse)
225
226    def signature(self, node):
227        want_comma = []
228        def write_comma():
229            if want_comma:
230                self.write_c(', ')
231            else:
232                want_comma.append(True)
233
234        # for C
235        for arg in node.args:
236            # CRUFT: 2.7 uses arg.id, 3.x uses arg.arg
237            try:
238                arg_name = arg.arg
239            except AttributeError:
240                arg_name = arg.id
241            self.arguments.append(arg_name)
242
243        padding = [None] *(len(node.args) - len(node.defaults))
244        for arg, default in zip(node.args, padding + node.defaults):
245            if default is not None:
246                # CRUFT: 2.7 uses arg.id, 3.x uses arg.arg
247                try:
248                    arg_name = arg.arg
249                except AttributeError:
250                    arg_name = arg.id
251                w_str = ("Default Parameters are unknown to C: '%s = %s"
252                         % arg_name, str(default.n))
253                self.warnings.append(w_str)
254
255    def decorators(self, node):
256        for decorator in node.decorator_list:
257            self.newline(decorator)
258            self.write_python('@')
259            self.visit(decorator)
260
261    # Statements
262
263    def visit_Assert(self, node):
264        self.newline(node)
265        self.write_c('assert ')
266        self.visit(node.test)
267        if node.msg is not None:
268            self.write_python(', ')
269            self.visit(node.msg)
270
271    def define_c_vars(self, target):
272        if hasattr(target, 'id'):
273        # a variable is considered an array if it apears in the agrument list
274        # and being assigned to. For example, the variable p in the following
275        # sniplet is a pointer, while q is not
276        # def somefunc(p, q):
277        #  p = q + 1
278        #  return
279        #
280            if target.id not in self.c_vars:
281                if target.id in self.arguments:
282                    idx = self.arguments.index(target.id)
283                    new_target = self.arguments[idx] + "[0]"
284                    if new_target not in self.c_pointers:
285                        target.id = new_target
286                        self.c_pointers.append(self.arguments[idx])
287                else:
288                    self.c_vars.append(target.id)
289
290    def add_semi_colon(self):
291        semi_pos = self.current_statement.find(';')
292        if semi_pos > 0.0:
293            self.current_statement = self.current_statement.replace(';', '')
294        self.write_c(';')
295
296    def visit_Assign(self, node):
297        self.add_current_line()
298        for idx, target in enumerate(node.targets): # multi assign, as in 'a = b = c = 7'
299            if idx:
300                self.write_c(' = ')
301            self.define_c_vars(target)
302            self.visit(target)
303        # Capture assigned tuple names, if any
304        targets = self.tuples[:]
305        del self.tuples[:]
306        self.write_c(' = ')
307        self.is_sequence = False
308        self.visited_args = False
309        self.visit(node.value)
310        self.add_semi_colon()
311        self.add_current_line()
312        # Assign tuples to tuples, if any
313        # TODO: doesn't handle swap:  a,b = b,a
314        for target, item in zip(targets, self.tuples):
315            self.visit(target)
316            self.write_c(' = ')
317            self.visit(item)
318            self.add_semi_colon()
319            self.add_current_line()
320        if self.is_sequence and not self.visited_args:
321            for target in node.targets:
322                if hasattr(target, 'id'):
323                    if target.id in self.c_vars and target.id not in self.c_dcl_pointers:
324                        if target.id not in self.c_dcl_pointers:
325                            self.c_dcl_pointers.append(target.id)
326                            if target.id in self.c_vars:
327                                self.c_vars.remove(target.id)
328        self.current_statement = ''
329
330    def visit_AugAssign(self, node):
331        if node.target.id not in self.c_vars:
332            if node.target.id not in self.arguments:
333                self.c_vars.append(node.target.id)
334        self.visit(node.target)
335        self.write_c(' ' + BINOP_SYMBOLS[type(node.op)] + '= ')
336        self.visit(node.value)
337        self.add_semi_colon()
338        self.add_current_line()
339
340    def visit_ImportFrom(self, node):
341        return  # import ignored
342        self.newline(node)
343        self.write_python('from %s%s import ' %('.' * node.level, node.module))
344        for idx, item in enumerate(node.names):
345            if idx:
346                self.write_python(', ')
347            self.write_python(item)
348
349    def visit_Import(self, node):
350        return  # import ignored
351        self.newline(node)
352        for item in node.names:
353            self.write_python('import ')
354            self.visit(item)
355
356    def visit_Expr(self, node):
357        self.newline(node)
358        self.generic_visit(node)
359
360    def write_c_pointers(self, start_var):
361        if self.c_dcl_pointers:
362            var_list = []
363            for c_ptr in self.c_dcl_pointers:
364                if c_ptr not in self.arguments:
365                    var_list.append("*" + c_ptr)
366                if c_ptr in self.c_vars:
367                    self.c_vars.remove(c_ptr)
368            if var_list:
369                c_dcl = "    double " + ", ".join(var_list) + ";\n"
370                self.c_proc.insert(start_var, c_dcl)
371                start_var += 1
372        return start_var
373
374    def insert_c_vars(self, start_var):
375        have_decls = False
376        start_var = self.write_c_pointers(start_var)
377        if self.c_int_vars:
378            for var in self.c_int_vars:
379                if var in self.c_vars:
380                    self.c_vars.remove(var)
381            decls = ", ".join(self.c_int_vars)
382            self.c_proc.insert(start_var, "    int " + decls + ";\n")
383            have_decls = True
384            start_var += 1
385
386        if self.c_vars:
387            decls = ", ".join(self.c_vars)
388            self.c_proc.insert(start_var, "    double " + decls + ";\n")
389            have_decls = True
390            start_var += 1
391
392        if self.c_vectors:
393            for vec_number, vec_value  in enumerate(self.c_vectors):
394                name = "vec" + str(vec_number + 1)
395                decl = "    double " + name + "[] = {" + vec_value + "};"
396                self.c_proc.insert(start_var, decl + "\n")
397                start_var += 1
398
399        del self.c_vars[:]
400        del self.c_int_vars[:]
401        del self.c_vectors[:]
402        del self.c_pointers[:]
403        del self.c_dcl_pointers[:]
404        if have_decls:
405            self.c_proc.insert(start_var, "\n")
406
407    def insert_signature(self):
408        arg_decls = []
409        for arg in self.arguments:
410            decl = "double " + arg
411            if arg in self.c_pointers:
412                decl += "[]"
413            arg_decls.append(decl)
414        args_str = ", ".join(arg_decls)
415        method_sig = 'double ' + self.current_function + '(' + args_str + ")"
416        if self.signature_line >= 0:
417            self.c_proc.insert(self.signature_line, method_sig)
418
419    def visit_FunctionDef(self, node):
420        if self.current_function:
421            self.unsupported(node, "function within a function")
422        self.current_function = node.name
423
424        self.newline(extra=1)
425        self.decorators(node)
426        self.newline(node)
427        self.arguments = []
428        self.visit(node.args)
429        # for C
430        self.signature_line = len(self.c_proc)
431        self.add_c_line("\n{")
432        start_vars = len(self.c_proc) + 1
433        self.body(node.body)
434        self.add_c_line("}\n")
435        self.insert_signature()
436        self.insert_c_vars(start_vars)
437        self.c_pointers = []
438        self.current_function = ""
439
440    def visit_ClassDef(self, node):
441        have_args = []
442        def paren_or_comma():
443            if have_args:
444                self.write_python(', ')
445            else:
446                have_args.append(True)
447                self.write_python('(')
448
449        self.newline(extra=2)
450        self.decorators(node)
451        self.newline(node)
452        self.write_python('class %s' % node.name)
453        for base in node.bases:
454            paren_or_comma()
455            self.visit(base)
456        # CRUFT: python 2.6 does not have "keywords" attribute
457        if hasattr(node, 'keywords'):
458            for keyword in node.keywords:
459                paren_or_comma()
460                self.write_python(keyword.arg + '=')
461                self.visit(keyword.value)
462            if node.starargs is not None:
463                paren_or_comma()
464                self.write_python('*')
465                self.visit(node.starargs)
466            if node.kwargs is not None:
467                paren_or_comma()
468                self.write_python('**')
469                self.visit(node.kwargs)
470        self.write_python(have_args and '):' or ':')
471        self.body(node.body)
472
473    def visit_If(self, node):
474        self.write_c('if ')
475        self.visit(node.test)
476        self.write_c(' {')
477        self.body(node.body)
478        self.add_c_line('}')
479        while True:
480            else_ = node.orelse
481            if len(else_) == 0:
482                break
483            #elif hasattr(else_, 'orelse'):
484            elif len(else_) == 1 and isinstance(else_[0], ast.If):
485                node = else_[0]
486                #self.newline()
487                self.write_c('else if ')
488                self.visit(node.test)
489                self.write_c(' {')
490                self.body(node.body)
491                self.add_current_line()
492                self.add_c_line('}')
493                #break
494            else:
495                self.newline()
496                self.write_c('else {')
497                self.body(node.body)
498                self.add_c_line('}')
499                break
500
501    def get_for_range(self, node):
502        stop = ""
503        start = '0'
504        step = '1'
505        for_args = []
506        temp_statement = self.current_statement
507        self.current_statement = ''
508        for arg in node.iter.args:
509            self.visit(arg)
510            for_args.append(self.current_statement)
511            self.current_statement = ''
512        self.current_statement = temp_statement
513        if len(for_args) == 1:
514            stop = for_args[0]
515        elif len(for_args) == 2:
516            start = for_args[0]
517            stop = for_args[1]
518        elif len(for_args) == 3:
519            start = for_args[0]
520            stop = for_args[1]
521            start = for_args[2]
522        else:
523            raise("Ilegal for loop parameters")
524        return start, stop, step
525
526    def visit_For(self, node):
527        # node: for iterator is stored in node.target.
528        # Iterator name is in node.target.id.
529        self.add_current_line()
530        fForDone = False
531        self.current_statement = ''
532        if hasattr(node.iter, 'func'):
533            if hasattr(node.iter.func, 'id'):
534                if node.iter.func.id == 'range':
535                    self.visit(node.target)
536                    iterator = self.current_statement
537                    self.current_statement = ''
538                    if iterator not in self.c_int_vars:
539                        self.c_int_vars.append(iterator)
540                    start, stop, step = self.get_for_range(node)
541                    self.write_c("for(" + iterator + "=" + str(start) +
542                                 " ; " + iterator + " < " + str(stop) +
543                                 " ; " + iterator + " += " + str(step) + ") {")
544                    self.body_or_else(node)
545                    self.write_c("}")
546                    fForDone = True
547        if not fForDone:
548            # Generate the statement that is causing the error
549            self.current_statement = ''
550            self.write_c('for ')
551            self.visit(node.target)
552            self.write_c(' in ')
553            self.visit(node.iter)
554            self.write_c(':')
555            # report the error
556            self.unsupported("unsupported " + self.current_statement)
557
558    def visit_While(self, node):
559        self.newline(node)
560        self.write_c('while ')
561        self.visit(node.test)
562        self.write_c(':')
563        self.body_or_else(node)
564
565    def visit_With(self, node):
566        self.unsupported(node)
567        self.newline(node)
568        self.write_python('with ')
569        self.visit(node.context_expr)
570        if node.optional_vars is not None:
571            self.write_python(' as ')
572            self.visit(node.optional_vars)
573        self.write_python(':')
574        self.body(node.body)
575
576    def visit_Pass(self, node):
577        self.newline(node)
578        #self.write_python('pass')
579
580    def visit_Print(self, node):
581        # TODO: print support would be nice, though hard to do
582        self.unsupported(node)
583        # CRUFT: python 2.6 only
584        self.newline(node)
585        self.write_c('print ')
586        want_comma = False
587        if node.dest is not None:
588            self.write_c(' >> ')
589            self.visit(node.dest)
590            want_comma = True
591        for value in node.values:
592            if want_comma:
593                self.write_c(', ')
594            self.visit(value)
595            want_comma = True
596        if not node.nl:
597            self.write_c(',')
598
599    def visit_Delete(self, node):
600        self.unsupported(node)
601        self.newline(node)
602        self.write_python('del ')
603        for idx, target in enumerate(node):
604            if idx:
605                self.write_python(', ')
606            self.visit(target)
607
608    def visit_TryExcept(self, node):
609        self.unsupported(node)
610        self.newline(node)
611        self.write_python('try:')
612        self.body(node.body)
613        for handler in node.handlers:
614            self.visit(handler)
615
616    def visit_TryFinally(self, node):
617        self.unsupported(node)
618        self.newline(node)
619        self.write_python('try:')
620        self.body(node.body)
621        self.newline(node)
622        self.write_python('finally:')
623        self.body(node.finalbody)
624
625    def visit_Global(self, node):
626        self.unsupported(node)
627        self.newline(node)
628        self.write_python('global ' + ', '.join(node.names))
629
630    def visit_Nonlocal(self, node):
631        self.newline(node)
632        self.write_python('nonlocal ' + ', '.join(node.names))
633
634    def visit_Return(self, node):
635        self.add_current_line()
636        if node.value is None:
637            self.write_c('return')
638        else:
639            self.write_c('return(')
640            self.visit(node.value)
641        self.write_c(')')
642        self.add_semi_colon()
643        self.add_c_line(self.current_statement)
644        self.current_statement = ''
645
646    def visit_Break(self, node):
647        self.newline(node)
648        self.write_c('break')
649
650    def visit_Continue(self, node):
651        self.newline(node)
652        self.write_c('continue')
653
654    def visit_Raise(self, node):
655        self.unsupported(node)
656        # CRUFT: Python 2.6 / 3.0 compatibility
657        self.newline(node)
658        self.write_python('raise')
659        if hasattr(node, 'exc') and node.exc is not None:
660            self.write_python(' ')
661            self.visit(node.exc)
662            if node.cause is not None:
663                self.write_python(' from ')
664                self.visit(node.cause)
665        elif hasattr(node, 'type') and node.type is not None:
666            self.visit(node.type)
667            if node.inst is not None:
668                self.write_python(', ')
669                self.visit(node.inst)
670            if node.tback is not None:
671                self.write_python(', ')
672                self.visit(node.tback)
673
674    # Expressions
675
676    def visit_Attribute(self, node):
677        self.unsupported(node, "attribute reference a.b not supported")
678        self.visit(node.value)
679        self.write_python('.' + node.attr)
680
681    def visit_Call(self, node):
682        want_comma = []
683        def write_comma():
684            if want_comma:
685                self.write_c(', ')
686            else:
687                want_comma.append(True)
688        if hasattr(node.func, 'id'):
689            if node.func.id not in self.c_functions:
690                self.c_functions.append(node.func.id)
691            if node.func.id == 'abs':
692                self.write_c("fabs ")
693            elif node.func.id == 'int':
694                self.write_c('(int) ')
695            elif node.func.id == "SINCOS":
696                self.write_sincos(node)
697                return
698            else:
699                self.visit(node.func)
700        else:
701            self.visit(node.func)
702        self.write_c('(')
703        for arg in node.args:
704            write_comma()
705            self.visited_args = True
706            self.visit(arg)
707        for keyword in node.keywords:
708            write_comma()
709            self.write_c(keyword.arg + '=')
710            self.visit(keyword.value)
711        if hasattr(node, 'starargs'):
712            if node.starargs is not None:
713                write_comma()
714                self.write_c('*')
715                self.visit(node.starargs)
716        if hasattr(node, 'kwargs'):
717            if node.kwargs is not None:
718                write_comma()
719                self.write_c('**')
720                self.visit(node.kwargs)
721        self.write_c(');')
722
723    def visit_Name(self, node):
724        self.write_c(node.id)
725        if node.id in self.c_pointers and not self.in_subref:
726            self.write_c("[0]")
727        name = ""
728        sub = node.id.find("[")
729        if sub > 0:
730            name = node.id[0:sub].strip()
731        else:
732            name = node.id
733        # add variable to c_vars if it ins't there yet, not an argument and not a number
734        if (name not in self.c_functions and name not in self.c_vars and
735                name not in self.c_int_vars and name not in self.arguments and
736                name not in self.c_constants and not name.isdigit()):
737            if self.in_subscript:
738                self.c_int_vars.append(node.id)
739            else:
740                self.c_vars.append(node.id)
741
742    def visit_Str(self, node):
743        self.write_c(repr(node.s))
744
745    def visit_Bytes(self, node):
746        self.write_c(repr(node.s))
747
748    def visit_Num(self, node):
749        self.write_c(repr(node.n))
750
751    def visit_Tuple(self, node):
752        for idx, item in enumerate(node.elts):
753            if idx:
754                self.tuples.append(item)
755            else:
756                self.visit(item)
757
758    def sequence_visit(left, right):
759        def visit(self, node):
760            self.is_sequence = True
761            s = ""
762            for idx, item in enumerate(node.elts):
763                if idx > 0 and s:
764                    s += ', '
765                if hasattr(item, 'id'):
766                    s += item.id
767                elif hasattr(item, 'n'):
768                    s += str(item.n)
769            if s:
770                self.c_vectors.append(s)
771                vec_name = "vec"  + str(len(self.c_vectors))
772                self.write_c(vec_name)
773        return visit
774
775    visit_List = sequence_visit('[', ']')
776    visit_Set = sequence_visit('{', '}')
777    del sequence_visit
778
779    def visit_Dict(self, node):
780        self.unsupported(node)
781        self.write_python('{')
782        for idx, (key, value) in enumerate(zip(node.keys, node.values)):
783            if idx:
784                self.write_python(', ')
785            self.visit(key)
786            self.write_python(': ')
787            self.visit(value)
788        self.write_python('}')
789
790    def get_special_power(self, string):
791        function_name = ''
792        is_negative_exp = False
793        if isevaluable(str(self.current_statement)):
794            exponent = eval(string)
795            is_negative_exp = exponent < 0
796            abs_exponent = abs(exponent)
797            if abs_exponent == 2:
798                function_name = "square"
799            elif abs_exponent == 3:
800                function_name = "cube"
801            elif abs_exponent == 0.5:
802                function_name = "sqrt"
803            elif abs_exponent == 1.0/3.0:
804                function_name = "cbrt"
805        if function_name == '':
806            function_name = "pow"
807        return function_name, is_negative_exp
808
809    def translate_power(self, node):
810        # get exponent by visiting the right hand argument.
811        function_name = "pow"
812        temp_statement = self.current_statement
813        # 'visit' functions write the results to the 'current_statement' class memnber
814        # Here, a temporary variable, 'temp_statement', is used, that enables the
815        # use of the 'visit' function
816        self.current_statement = ''
817        self.visit(node.right)
818        exponent = self.current_statement.replace(' ', '')
819        function_name, is_negative_exp = self.get_special_power(self.current_statement)
820        self.current_statement = temp_statement
821        if is_negative_exp:
822            self.write_c("1.0 /(")
823        self.write_c(function_name + "(")
824        self.visit(node.left)
825        if function_name == "pow":
826            self.write_c(", ")
827            self.visit(node.right)
828        self.write_c(")")
829        if is_negative_exp:
830            self.write_c(")")
831        self.write_c(" ")
832
833    def translate_integer_divide(self, node):
834        self.write_c("(int)(")
835        self.visit(node.left)
836        self.write_c(") /(int)(")
837        self.visit(node.right)
838        self.write_c(")")
839
840    def visit_BinOp(self, node):
841        self.write_c("(")
842        if '%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.Pow]:
843            self.translate_power(node)
844        elif '%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.FloorDiv]:
845            self.translate_integer_divide(node)
846        else:
847            self.visit(node.left)
848            self.write_c(' %s ' % BINOP_SYMBOLS[type(node.op)])
849            self.visit(node.right)
850        self.write_c(")")
851
852    # for C
853    def visit_BoolOp(self, node):
854        self.write_c('(')
855        for idx, value in enumerate(node.values):
856            if idx:
857                self.write_c(' %s ' % BOOLOP_SYMBOLS[type(node.op)])
858            self.visit(value)
859        self.write_c(')')
860
861    def visit_Compare(self, node):
862        self.write_c('(')
863        self.visit(node.left)
864        for op, right in zip(node.ops, node.comparators):
865            self.write_c(' %s ' % CMPOP_SYMBOLS[type(op)])
866            self.visit(right)
867        self.write_c(')')
868
869    def visit_UnaryOp(self, node):
870        self.write_c('(')
871        op = UNARYOP_SYMBOLS[type(node.op)]
872        self.write_c(op)
873        if op == 'not':
874            self.write_c(' ')
875        self.visit(node.operand)
876        self.write_c(')')
877
878    def visit_Subscript(self, node):
879        if node.value.id not in self.c_constants:
880            if node.value.id not in self.c_pointers:
881                self.c_pointers.append(node.value.id)
882        self.in_subref = True
883        self.visit(node.value)
884        self.in_subref = False
885        self.write_c('[')
886        self.in_subscript = True
887        self.visit(node.slice)
888        self.in_subscript = False
889        self.write_c(']')
890
891    def visit_Slice(self, node):
892        if node.lower is not None:
893            self.visit(node.lower)
894        self.write_python(':')
895        if node.upper is not None:
896            self.visit(node.upper)
897        if node.step is not None:
898            self.write_python(':')
899            if not(isinstance(node.step, Name) and node.step.id == 'None'):
900                self.visit(node.step)
901
902    def visit_ExtSlice(self, node):
903        for idx, item in node.dims:
904            if idx:
905                self.write_python(', ')
906            self.visit(item)
907
908    def visit_Yield(self, node):
909        self.unsupported(node)
910        self.write_python('yield ')
911        self.visit(node.value)
912
913    def visit_Lambda(self, node):
914        self.unsupported(node)
915        self.write_python('lambda ')
916        self.visit(node.args)
917        self.write_python(': ')
918        self.visit(node.body)
919
920    def visit_Ellipsis(self, node):
921        self.unsupported(node)
922        self.write_python('Ellipsis')
923
924    def generator_visit(left, right):
925        def visit(self, node):
926            self.write_python(left)
927            self.write_c(left)
928            self.visit(node.elt)
929            for comprehension in node.generators:
930                self.visit(comprehension)
931            self.write_c(right)
932            #self.write_python(right)
933        return visit
934
935    visit_ListComp = generator_visit('[', ']')
936    visit_GeneratorExp = generator_visit('(', ')')
937    visit_SetComp = generator_visit('{', '}')
938    del generator_visit
939
940    def visit_DictComp(self, node):
941        self.unsupported(node)
942        self.write_python('{')
943        self.visit(node.key)
944        self.write_python(': ')
945        self.visit(node.value)
946        for comprehension in node.generators:
947            self.visit(comprehension)
948        self.write_python('}')
949
950    def visit_IfExp(self, node):
951        self.visit(node.body)
952        self.write_c(' if ')
953        self.visit(node.test)
954        self.write_c(' else ')
955        self.visit(node.orelse)
956
957    def visit_Starred(self, node):
958        self.write_c('*')
959        self.visit(node.value)
960
961    def visit_Repr(self, node):
962        # CRUFT: python 2.6 only
963        self.write_c('`')
964        self.visit(node.value)
965        self.write_python('`')
966
967    # Helper Nodes
968
969    def visit_alias(self, node):
970        self.unsupported(node)
971        self.write_python(node.name)
972        if node.asname is not None:
973            self.write_python(' as ' + node.asname)
974
975    def visit_comprehension(self, node):
976        self.write_c(' for ')
977        self.visit(node.target)
978        self.write_C(' in ')
979        #self.write_python(' in ')
980        self.visit(node.iter)
981        if node.ifs:
982            for if_ in node.ifs:
983                self.write_python(' if ')
984                self.visit(if_)
985
986    def visit_arguments(self, node):
987        self.signature(node)
988
989    def unsupported(self, node, message=None):
990        if hasattr(node, "value"):
991            lineno = node.value.lineno
992        elif hasattr(node, "iter"):
993            lineno = node.iter.lineno
994        else:
995            #print(dir(node))
996            lineno = 0
997
998        lineno += self.lineno_offset
999        if self.fname:
1000            location = "%s(%d)" % (self.fname, lineno)
1001        else:
1002            location = "%d" % (self.fname, lineno)
1003        if self.current_function:
1004            location += ", function %s" % self.current_function
1005        if message is None:
1006            message = node.__class__.__name__ + " syntax not supported"
1007        raise SyntaxError("[%s] %s" % (location, message))
1008
1009def print_function(f=None):
1010    """
1011    Print out the code for the function
1012    """
1013    # Include some comments to see if they get printed
1014    import ast
1015    import inspect
1016    if f is not None:
1017        tree = ast.parse(inspect.getsource(f))
1018        tree_source = to_source(tree)
1019        print(tree_source)
1020
1021def define_constant(name, value, block_size=1):
1022    # type: (str, any, int) -> str
1023    """
1024    Convert a python constant into a C constant of the same name.
1025
1026    Returns the C declaration of the constant as a string, possibly containing
1027    line feeds.  The string will not be indented.
1028
1029    Supports int, double and sequences of double.
1030    """
1031    const = "constant "  # OpenCL needs globals to be constant
1032    if isinstance(value, int):
1033        parts = [const + "int ", name, " = ", "%d"%value, ";"]
1034    elif isinstance(value, float):
1035        parts = [const + "double ", name, " = ", "%.15g"%value, ";"]
1036    else:
1037        try:
1038            len(value)
1039        except TypeError:
1040            raise TypeError("constant %s must be int, float or [float, ...]"%name)
1041        # extend constant arrays to a multiple of 4; not sure if this
1042        # is necessary, but some OpenCL targets broke if the number
1043        # of parameters in the parameter table was not a multiple of 4,
1044        # so do it for all constant arrays to be safe.
1045        if len(value)%block_size != 0:
1046            value = list(value) + [0.]*(block_size - len(value)%block_size)
1047        elements = ["%.15g"%v for v in value]
1048        parts = [const + "double ", name, "[]", " = ",
1049                 "{\n   ", ", ".join(elements), "\n};"]
1050
1051    return "".join(parts)
1052
1053
1054# Modified from the following:
1055#
1056#    http://code.activestate.com/recipes/578272-topological-sort/
1057#    Copyright (C) 2012 Sam Denton
1058#    License: MIT
1059def ordered_dag(dag):
1060    # type: (Dict[T, Set[T]]) -> Iterator[T]
1061    """
1062    Given a dag defined by a dictionary of {k1: [k2, ...]} yield keys
1063    in order such that every key occurs after the keys it depends upon.
1064
1065    This is an iterator not a sequence.  To reverse it use::
1066
1067        reversed(tuple(ordered_dag(dag)))
1068
1069    Raise an error if there are any cycles.
1070
1071    Keys are arbitrary hashable values.
1072    """
1073    # Local import to make the function stand-alone, and easier to borrow
1074    from functools import reduce
1075
1076    dag = dag.copy()
1077
1078    # make leaves depend on the empty set
1079    leaves = reduce(set.union, dag.values()) - set(dag.keys())
1080    dag.update({node: set() for node in leaves})
1081    while True:
1082        leaves = set(node for node, links in dag.items() if not links)
1083        if not leaves:
1084            break
1085        for node in leaves:
1086            yield node
1087        dag = {node: (links-leaves)
1088               for node, links in dag.items() if node not in leaves}
1089    if dag:
1090        raise ValueError("Cyclic dependes exists amongst these items:\n%s"
1091                         % ", ".join(str(node) for node in dag.keys()))
1092
1093def translate(functions, constants=None):
1094    # type: (List[(str, str, int)], Dict[str, any]) -> List[str]
1095    """
1096    Convert a set of functions
1097    """
1098    snippets = []
1099    #snippets.append("#include <math.h>")
1100    #snippets.append("")
1101    for source, fname, lineno in functions:
1102        line_directive = '#line %d "%s"'%(lineno, fname.replace('\\', '\\\\'))
1103        snippets.append(line_directive)
1104        tree = ast.parse(source)
1105        c_code = to_source(tree, constants=constants, fname=fname, lineno=lineno)
1106        snippets.append(c_code)
1107    return snippets
1108
1109def main():
1110    import os
1111    print("Parsing...using Python" + sys.version)
1112    if len(sys.argv) == 1:
1113        print("""\
1114Usage: python py2c.py <infile> [<outfile>]
1115
1116if outfile is omitted, output file is '<infile>.c'
1117""")
1118        return
1119
1120    fname_in = sys.argv[1]
1121    if len(sys.argv) == 2:
1122        fname_base = os.path.splitext(fname_in)[0]
1123        fname_out = str(fname_base) + '.c'
1124    else:
1125        fname_out = sys.argv[2]
1126
1127    with open(fname_in, "r") as python_file:
1128        code = python_file.read()
1129    name = "gauss"
1130    code = (code
1131            .replace(name+'.n', 'GAUSS_N')
1132            .replace(name+'.z', 'GAUSS_Z')
1133            .replace(name+'.w', 'GAUSS_W'))
1134
1135    translation = translate([(code, fname_in, 1)])[0]
1136
1137    with open(fname_out, "w") as file_out:
1138        file_out.write(translation)
1139    print("...Done")
1140
1141if __name__ == "__main__":
1142    main()
Note: See TracBrowser for help on using the repository browser.