source: sasmodels/sasmodels/py2c.py @ 4339764

Last change on this file since 4339764 was 4339764, checked in by Omer Eisenberg <omereis@…>, 6 years ago

printing warnings

  • Property mode set to 100644
File size: 44.4 KB
Line 
1"""
2    py2c
3    ~~~~
4
5    Convert simple numeric python code into C code.
6
7    The translate() function works on
8
9    Variables definition in C
10    -------------------------
11    Defining variables within the translate function is a bit of a guess work,
12    using following rules:
13    *   By default, a variable is a 'double'.
14    *   Variable in a for loop is an int.
15    *   Variable that is references with brackets is an array of doubles. The
16        variable within the brackets is integer. For example, in the
17        reference 'var1[var2]', var1 is a double array, and var2 is an integer.
18    *   Assignment to an argument makes that argument an array, and the index
19        in that assignment is 0.
20        For example, the following python code::
21            def func(arg1, arg2):
22                arg2 = 17.
23        is translated to the following C code::
24            double func(double arg1)
25            {
26                arg2[0] = 17.0;
27            }
28        For example, the following python code is translated to the
29        following C code::
30
31            def func(arg1, arg2):          double func(double arg1) {
32                arg2 = 17.                      arg2[0] = 17.0;
33                                            }
34    *   All functions are defined as double, even if there is no
35        return statement.
36
37Based on codegen.py:
38
39    :copyright: Copyright 2008 by Armin Ronacher.
40    :license: BSD.
41"""
42"""
43Update Notes
44============
4511/22/2017, O.E.   Each 'visit_*' method is to build a C statement string. It
46                    shold insert 4 blanks per indentation level.
47                    The 'body' method will combine all the strings, by adding
48                    the 'current_statement' to the c_proc string list
49   11/2017, OE: variables, argument definition implemented.
50   Note: An argument is considered an array if it is the target of an
51        assignment. In that case it is translated to <var>[0]
5211/27/2017, OE: 'pow' basicly working
53  /12/2017, OE: Multiple assignment: a1,a2,...,an=b1,b2,...bn implemented
54  /12/2017, OE: Power function, including special cases of
55                square(x)(pow(x,2)) and cube(x)(pow(x,3)), implemented in
56                translate_power, called from visit_BinOp
5712/07/2017, OE: Translation of integer division, '\\' in python, implemented
58                in translate_integer_divide, called from visit_BinOp
5912/07/2017, OE: C variable definition handled in 'define_c_vars'
60              : Python integer division, '//', translated to C in
61                'translate_integer_divide'
6212/15/2017, OE: Precedence maintained by writing opening and closing
63                parenthesesm '(',')', in procedure 'visit_BinOp'.
6412/18/2017, OE: Added call to 'add_current_line()' at the beginning
65                of visit_Return
6601/4/2018, O.E. Added functions 'get_files_names()', , 'print_usage()'. The get_files_names()
67                function retrieves the in/out file names form the command line, and returns
68                true/false if the number of parameters is valid. In case of no input
69                parameters, usage is prompt and the program terminates.
7001/4/2018, O.E. 'translate(functions, constants=None)' returns string, instaed of a list
7101/04/2017 O.E. Fixed bug in 'visit_If': visiting node.orelse in case else exists.
7201/05/2018 O.E. New class C_Vector defined. It holds name, type, size and initial values.
73                Assignment to that class in 'sequence_visit'.
74                Reading and inserting declaration in 'insert_c_vars'
75
76
77"""
78import ast
79import sys
80from ast import NodeVisitor
81
82BINOP_SYMBOLS = {}
83BINOP_SYMBOLS[ast.Add] = '+'
84BINOP_SYMBOLS[ast.Sub] = '-'
85BINOP_SYMBOLS[ast.Mult] = '*'
86BINOP_SYMBOLS[ast.Div] = '/'
87BINOP_SYMBOLS[ast.Mod] = '%'
88BINOP_SYMBOLS[ast.Pow] = '**'
89BINOP_SYMBOLS[ast.LShift] = '<<'
90BINOP_SYMBOLS[ast.RShift] = '>>'
91BINOP_SYMBOLS[ast.BitOr] = '|'
92BINOP_SYMBOLS[ast.BitXor] = '^'
93BINOP_SYMBOLS[ast.BitAnd] = '&'
94BINOP_SYMBOLS[ast.FloorDiv] = '//'
95
96BOOLOP_SYMBOLS = {}
97BOOLOP_SYMBOLS[ast.And] = '&&'
98BOOLOP_SYMBOLS[ast.Or] = '||'
99
100CMPOP_SYMBOLS = {}
101CMPOP_SYMBOLS[ast.Eq] = '=='
102CMPOP_SYMBOLS[ast.NotEq] = '!='
103CMPOP_SYMBOLS[ast.Lt] = '<'
104CMPOP_SYMBOLS[ast.LtE] = '<='
105CMPOP_SYMBOLS[ast.Gt] = '>'
106CMPOP_SYMBOLS[ast.GtE] = '>='
107CMPOP_SYMBOLS[ast.Is] = 'is'
108CMPOP_SYMBOLS[ast.IsNot] = 'is not'
109CMPOP_SYMBOLS[ast.In] = 'in'
110CMPOP_SYMBOLS[ast.NotIn] = 'not in'
111
112UNARYOP_SYMBOLS = {}
113UNARYOP_SYMBOLS[ast.Invert] = '~'
114UNARYOP_SYMBOLS[ast.Not] = 'not'
115UNARYOP_SYMBOLS[ast.UAdd] = '+'
116UNARYOP_SYMBOLS[ast.USub] = '-'
117
118
119def to_source(tree, constants=None, fname=None, lineno=0):
120    """
121    This function can convert a syntax tree into C sourcecode.
122    """
123    generator = SourceGenerator(constants=constants, fname=fname, lineno=lineno)
124    generator.visit(tree)
125    c_code = "\n".join(generator.c_proc)
126    c_code = "".join(generator.c_proc)
127    return (c_code, generator.warnings)
128
129def isevaluable(s):
130    try:
131        eval(s)
132        return True
133    except Exception:
134        return False
135
136class C_Vector():
137    def __init__(self):
138        self.size = 0
139        self.type = "double"
140        self.values = []
141        self.assign_line = -1
142        self.name = ""
143
144    def get_leading_space(self, indent_with="    ", indent_level = 1):
145        leading_space = ""
146        for n in range(indent_level):
147            leading_space += indent_with
148        return leading_space
149       
150    def item_declare_string(self, indent_with="    ", indent_level = 1):
151        declare_string = self.name + "[" + str(len(self.values)) + "]"
152        return declare_string
153   
154    def declare_string(self, indent_with="    ", indent_level = 1):
155        declare_string = self.get_leading_space(indent_with, indent_level)
156        declare_string += self.type + " " + item_declare_string(indent_with, indent_level)
157#        self.name + "[" + str(len(self.values)) + "];\n"
158        return declare_string
159   
160    def get_assign(self, indent_with="    ", indent_level = 1):
161        assign_loop = []
162        c_string = self.get_leading_space(indent_with, indent_level - 1)
163        c_string += "for (n=0 ; n < " + str(len(self.values)) + " ; n++) {"
164        assign_loop.append(c_string)
165        for n,value in enumerate(self.values):
166            c_string = self.get_leading_space(indent_with, indent_level)
167            c_string += self.name + "[" + str(n) + "] = " + str(value) + ";"
168            assign_loop.append(c_string)
169        c_string = self.get_leading_space(indent_with, indent_level - 1) + "}"
170        assign_loop.append(c_string)
171        return (assign_loop)
172
173class SourceGenerator(NodeVisitor):
174    """This visitor is able to transform a well formed syntax tree into python
175    sourcecode.  For more details have a look at the docstring of the
176    `node_to_source` function.
177    """
178
179    def __init__(self, indent_with="    ", add_line_information=False,
180                 constants=None, fname=None, lineno=0):
181        self.result = []
182        self.indent_with = indent_with
183        self.add_line_information = add_line_information
184        self.indentation = 0
185        self.new_lines = 0
186
187        # for C
188        self.c_proc = []
189        self.signature_line = 0
190        self.arguments = []
191        self.current_function = ""
192        self.fname = fname
193        self.lineno_offset = lineno
194        self.warnings = []
195        self.statements = []
196        self.current_statement = ""
197        # TODO: use set rather than list for c_vars, ...
198        self.c_vars = []
199        self.c_int_vars = []
200        self.c_pointers = []
201        self.c_dcl_pointers = []
202        self.c_functions = []
203        self.c_vectors = []
204        self.c_constants = constants if constants is not None else {}
205        self.in_subref = False
206        self.in_subscript = False
207        self.tuples = []
208        self.required_functions = []
209        self.is_sequence = False
210        self.visited_args = False
211        self.inside_if = False
212        self.C_Vectors = []
213        self.assign_target = ""
214        self.assign_c_vector = []
215
216    def write_python(self, x):
217        if self.new_lines:
218            if self.result:
219                self.result.append('\n' * self.new_lines)
220            self.result.append(self.indent_with * self.indentation)
221            self.new_lines = 0
222        self.result.append(x)
223
224    def write_c(self, statement):
225        # TODO: build up as a list rather than adding to string
226        self.current_statement += statement
227
228    def add_c_line(self, line):
229        indentation = self.indent_with * self.indentation
230        self.c_proc.append("".join((indentation, line, "\n")))
231
232    def add_current_line(self):
233        if self.current_statement:
234            self.add_c_line(self.current_statement)
235            self.current_statement = ''
236
237    def add_unique_var(self, new_var):
238        if new_var not in self.c_vars:
239            self.c_vars.append(str(new_var))
240
241    def write_sincos(self, node):
242        angle = str(node.args[0].id)
243        self.write_c(node.args[1].id + " = sin(" + angle + ");")
244        self.add_current_line()
245        self.write_c(node.args[2].id + " = cos(" + angle + ");")
246        self.add_current_line()
247        for arg in node.args:
248            self.add_unique_var(arg.id)
249
250    def write_print(self, node):
251        self.write_c ("printf (")
252        c_str = self.current_statement
253        self.current_statement = ""
254        self.visit(node.args[0])
255        print_params = self.current_statement # save print agruments
256        self.current_statement = c_str
257        print_params = print_params[1:(len(print_params) - 1)] # remove parentheses
258        prcnt = print_params.rfind('%')
259        if prcnt >= 0:
260            format_string = print_params[0:prcnt-1].strip()
261            format_string = format_string[1:len(format_string) - 1]
262            var_string = print_params[prcnt+1:len(print_params)].strip()
263        else:
264            format_string = print_params
265            var_string = ""
266        format_string= format_string[:len(format_string)]
267        self.write_c('"' + format_string + '\\n"')
268        if (len(var_string) > 0):
269            self.write_c(', ' + var_string)
270        self.write_c(");")
271        self.add_current_line()
272
273
274
275    def newline(self, node=None, extra=0):
276        self.new_lines = max(self.new_lines, 1 + extra)
277        if node is not None and self.add_line_information:
278            self.write_c('# line: %s' % node.lineno)
279            self.new_lines = 1
280        if self.current_statement:
281            self.statements.append(self.current_statement)
282            self.current_statement = ''
283
284    def body(self, statements):
285        if self.current_statement:
286            self.add_current_line()
287        self.new_line = True
288        self.indentation += 1
289        for stmt in statements:
290            #if hasattr(stmt, 'targets') and hasattr(stmt.targets[0], 'id'):
291            #    target_name = stmt.targets[0].id # target name needed for debug only
292            self.visit(stmt)
293        self.add_current_line() # just for breaking point. to be deleted.
294        self.indentation -= 1
295
296    def body_or_else(self, node):
297        self.body(node.body)
298        if node.orelse:
299            self.newline()
300            self.write_c('else:')
301            self.body(node.orelse)
302
303    def signature(self, node):
304        want_comma = []
305        def write_comma():
306            if want_comma:
307                self.write_c(', ')
308            else:
309                want_comma.append(True)
310
311        # for C
312        for arg in node.args:
313            # CRUFT: 2.7 uses arg.id, 3.x uses arg.arg
314            try:
315                arg_name = arg.arg
316            except AttributeError:
317                arg_name = arg.id
318            self.arguments.append(arg_name)
319
320        padding = [None] *(len(node.args) - len(node.defaults))
321        for arg, default in zip(node.args, padding + node.defaults):
322            if default is not None:
323                # CRUFT: 2.7 uses arg.id, 3.x uses arg.arg
324                try:
325                    arg_name = arg.arg
326                except AttributeError:
327                    arg_name = arg.id
328                w_str = ("Default Parameters are unknown to C: '%s = %s" \
329                        % (arg_name, str(default.n)))
330                self.warnings.append(w_str)
331
332    def decorators(self, node):
333        for decorator in node.decorator_list:
334            self.newline(decorator)
335            self.write_python('@')
336            self.visit(decorator)
337
338    # Statements
339
340    def visit_Assert(self, node):
341        self.newline(node)
342        self.write_c('assert ')
343        self.visit(node.test)
344        if node.msg is not None:
345            self.write_python(', ')
346            self.visit(node.msg)
347
348    def define_c_vars(self, target):
349        if hasattr(target, 'id'):
350        # a variable is considered an array if it apears in the agrument list
351        # and being assigned to. For example, the variable p in the following
352        # sniplet is a pointer, while q is not
353        # def somefunc(p, q):
354        #  p = q + 1
355        #  return
356        #
357            if target.id not in self.c_vars:
358                if target.id in self.arguments:
359                    idx = self.arguments.index(target.id)
360                    new_target = self.arguments[idx] + "[0]"
361                    if new_target not in self.c_pointers:
362                        target.id = new_target
363                        self.c_pointers.append(self.arguments[idx])
364                else:
365                    self.c_vars.append(target.id)
366
367    def add_semi_colon(self):
368        semi_pos = self.current_statement.find(';')
369        if semi_pos > 0.0:
370            self.current_statement = self.current_statement.replace(';', '')
371        self.write_c(';')
372
373    def visit_Assign(self, node):
374        self.add_current_line()
375        for idx, target in enumerate(node.targets): # multi assign, as in 'a = b = c = 7'
376            if idx:
377                self.write_c(' = ')
378            self.define_c_vars(target)
379            self.visit(target)
380            if hasattr(target,'id'):
381                self.assign_target = target.id
382        # Capture assigned tuple names, if any
383        targets = self.tuples[:]
384        del self.tuples[:]
385        self.write_c(' = ')
386        self.is_sequence = False
387        self.visited_args = False
388        self.visit(node.value)
389        if len(self.assign_c_vector):
390            for c_statement in self.assign_c_vector:
391                self.add_c_line (c_statement)
392#                self.c_proc.append(c_statement)
393            self.assign_c_vector.clear()
394        else:
395            self.add_semi_colon()
396            self.add_current_line()
397        # Assign tuples to tuples, if any
398        # TODO: doesn't handle swap:  a,b = b,a
399            for target, item in zip(targets, self.tuples):
400                self.visit(target)
401                self.write_c(' = ')
402                self.visit(item)
403                self.add_semi_colon()
404                self.add_current_line()
405            if self.is_sequence and not self.visited_args:
406                for target in node.targets:
407                    if hasattr(target, 'id'):
408                        if target.id in self.c_vars and target.id not in self.c_dcl_pointers:
409                            if target.id not in self.c_dcl_pointers:
410                                self.c_dcl_pointers.append(target.id)
411                                if target.id in self.c_vars:
412                                    self.c_vars.remove(target.id)
413        self.current_statement = ''
414        self.assign_target = ''
415
416    def visit_AugAssign(self, node):
417        if node.target.id not in self.c_vars:
418            if node.target.id not in self.arguments:
419                self.c_vars.append(node.target.id)
420        self.visit(node.target)
421        self.write_c(' ' + BINOP_SYMBOLS[type(node.op)] + '= ')
422        self.visit(node.value)
423        self.add_semi_colon()
424        self.add_current_line()
425
426    def visit_ImportFrom(self, node):
427        return  # import ignored
428        self.newline(node)
429        self.write_python('from %s%s import ' %('.' * node.level, node.module))
430        for idx, item in enumerate(node.names):
431            if idx:
432                self.write_python(', ')
433            self.write_python(item)
434
435    def visit_Import(self, node):
436        return  # import ignored
437        self.newline(node)
438        for item in node.names:
439            self.write_python('import ')
440            self.visit(item)
441
442    def visit_Expr(self, node):
443        self.newline(node)
444        self.generic_visit(node)
445
446    def write_c_pointers(self, start_var):
447        if self.c_dcl_pointers:
448            var_list = []
449            for c_ptr in self.c_dcl_pointers:
450                if c_ptr not in self.arguments:
451                    var_list.append("*" + c_ptr)
452                if c_ptr in self.c_vars:
453                    self.c_vars.remove(c_ptr)
454            if var_list:
455                c_dcl = "    double " + ", ".join(var_list) + ";\n"
456                self.c_proc.insert(start_var, c_dcl)
457                start_var += 1
458        return start_var
459
460    def insert_c_vars(self, start_var):
461        have_decls = False
462        start_var = self.write_c_pointers(start_var)
463        if self.c_int_vars:
464            for var in self.c_int_vars:
465                if var in self.c_vars:
466                    self.c_vars.remove(var)
467            decls = ", ".join(self.c_int_vars)
468            self.c_proc.insert(start_var, "    int " + decls + ";\n")
469            have_decls = True
470            start_var += 1
471        if len(self.C_Vectors) > 0:
472            vec_declare = []
473            for c_vec in self.C_Vectors:
474                item_declare = c_vec.item_declare_string()
475                if item_declare not in vec_declare:
476                    vec_declare.append (item_declare)
477                if c_vec.name in self.c_vars:
478                    self.c_vars.remove(c_vec.name)
479            all_declare = "    double " + ", ".join (vec_declare) + ";\n"
480            self.c_proc.insert(start_var, all_declare)
481            start_var += 1
482
483        if self.c_vars:
484            decls = ", ".join(self.c_vars)
485            self.c_proc.insert(start_var, "    double " + decls + ";\n")
486            have_decls = True
487            start_var += 1
488
489        if self.c_vectors:
490            for vec_number, vec_value  in enumerate(self.c_vectors):
491                name = "vec" + str(vec_number + 1)
492                decl = "    double " + name + "[] = {" + vec_value + "};"
493                self.c_proc.insert(start_var, decl + "\n")
494                start_var += 1
495
496        del self.c_vars[:]
497        del self.c_int_vars[:]
498        del self.c_vectors[:]
499        del self.c_pointers[:]
500        del self.c_dcl_pointers[:]
501        if have_decls:
502            self.c_proc.insert(start_var, "\n")
503
504    def insert_signature(self):
505        arg_decls = []
506        for arg in self.arguments:
507            decl = "double " + arg
508            if arg in self.c_pointers:
509                decl += "[]"
510            arg_decls.append(decl)
511        args_str = ", ".join(arg_decls)
512        method_sig = 'double ' + self.current_function + '(' + args_str + ")"
513        if self.signature_line >= 0:
514            self.c_proc.insert(self.signature_line, method_sig)
515
516    def visit_FunctionDef(self, node):
517        if self.current_function:
518            self.unsupported(node, "function within a function")
519        self.current_function = node.name
520
521        warning_index = len(self.warnings)
522        self.newline(extra=1)
523        self.decorators(node)
524        self.newline(node)
525        self.arguments = []
526        self.visit(node.args)
527        # for C
528        self.signature_line = len(self.c_proc)
529        self.add_c_line("\n{")
530        start_vars = len(self.c_proc) + 1
531        self.body(node.body)
532        self.add_c_line("}\n")
533        self.insert_signature()
534        self.insert_c_vars(start_vars)
535#        self.c_pointers = []
536        self.c_pointers.clear()
537        self.C_Vectors.clear()
538        self.current_function = ""
539        if (warning_index != len(self.warnings)):
540            self.warnings.insert(warning_index, "Warning in function '" + node.name + "':")
541            self.warnings.append("Note: C compilation will fail")
542
543
544    def visit_ClassDef(self, node):
545        have_args = []
546        def paren_or_comma():
547            if have_args:
548                self.write_python(', ')
549            else:
550                have_args.append(True)
551                self.write_python('(')
552
553        self.newline(extra=2)
554        self.decorators(node)
555        self.newline(node)
556        self.write_python('class %s' % node.name)
557        for base in node.bases:
558            paren_or_comma()
559            self.visit(base)
560        # CRUFT: python 2.6 does not have "keywords" attribute
561        if hasattr(node, 'keywords'):
562            for keyword in node.keywords:
563                paren_or_comma()
564                self.write_python(keyword.arg + '=')
565                self.visit(keyword.value)
566            if node.starargs is not None:
567                paren_or_comma()
568                self.write_python('*')
569                self.visit(node.starargs)
570            if node.kwargs is not None:
571                paren_or_comma()
572                self.write_python('**')
573                self.visit(node.kwargs)
574        self.write_python(have_args and '):' or ':')
575        self.body(node.body)
576
577    def visit_If(self, node):
578        self.add_current_line()
579        self.inside_if = True
580        self.write_c('if ')
581        self.visit(node.test)
582        self.write_c(' {')
583        self.body(node.body)
584        self.add_c_line('}')
585        while True:
586            else_ = node.orelse
587            if len(else_) == 0:
588                break
589            #elif hasattr(else_, 'orelse'):
590            elif len(else_) == 1 and isinstance(else_[0], ast.If):
591                node = else_[0]
592                #self.newline()
593                self.write_c('else if ')
594                self.visit(node.test)
595                self.write_c(' {')
596                self.body(node.body)
597                self.add_current_line()
598                self.add_c_line('}')
599                #break
600            else:
601                self.newline()
602                self.write_c('else {')
603                self.body(node.orelse)
604                self.add_c_line('}')
605                break
606        self.inside_if = False
607
608    def get_for_range(self, node):
609        stop = ""
610        start = '0'
611        step = '1'
612        for_args = []
613        temp_statement = self.current_statement
614        self.current_statement = ''
615        for arg in node.iter.args:
616            self.visit(arg)
617            for_args.append(self.current_statement)
618            self.current_statement = ''
619        self.current_statement = temp_statement
620        if len(for_args) == 1:
621            stop = for_args[0]
622        elif len(for_args) == 2:
623            start = for_args[0]
624            stop = for_args[1]
625        elif len(for_args) == 3:
626            start = for_args[0]
627            stop = for_args[1]
628            start = for_args[2]
629        else:
630            raise("Ilegal for loop parameters")
631        return start, stop, step
632
633    def add_c_int_var (self, iterator):
634        if iterator not in self.c_int_vars:
635            self.c_int_vars.append(iterator)
636
637    def visit_For(self, node):
638        # node: for iterator is stored in node.target.
639        # Iterator name is in node.target.id.
640        self.add_current_line()
641        fForDone = False
642        self.current_statement = ''
643        if hasattr(node.iter, 'func'):
644            if hasattr(node.iter.func, 'id'):
645                if node.iter.func.id == 'range':
646                    self.visit(node.target)
647                    iterator = self.current_statement
648                    self.current_statement = ''
649                    self.add_c_int_var (iterator)
650                    start, stop, step = self.get_for_range(node)
651                    self.write_c("for(" + iterator + "=" + str(start) +
652                                 " ; " + iterator + " < " + str(stop) +
653                                 " ; " + iterator + " += " + str(step) + ") {")
654                    self.body_or_else(node)
655                    self.write_c("}")
656                    fForDone = True
657        if not fForDone:
658            # Generate the statement that is causing the error
659            self.current_statement = ''
660            self.write_c('for ')
661            self.visit(node.target)
662            self.write_c(' in ')
663            self.visit(node.iter)
664            self.write_c(':')
665            # report the error
666            self.unsupported("unsupported " + self.current_statement)
667
668    def visit_While(self, node):
669        self.newline(node)
670        self.write_c('while ')
671        self.visit(node.test)
672        self.write_c(' { ')
673        self.body_or_else(node)
674        self.write_c('}')
675        self.add_current_line()
676
677    def visit_With(self, node):
678        self.unsupported(node)
679        self.newline(node)
680        self.write_python('with ')
681        self.visit(node.context_expr)
682        if node.optional_vars is not None:
683            self.write_python(' as ')
684            self.visit(node.optional_vars)
685        self.write_python(':')
686        self.body(node.body)
687
688    def visit_Pass(self, node):
689        self.newline(node)
690        #self.write_python('pass')
691
692    def visit_Print(self, node):
693        # TODO: print support would be nice, though hard to do
694        self.unsupported(node)
695        # CRUFT: python 2.6 only
696        self.newline(node)
697        self.write_c('print ')
698        want_comma = False
699        if node.dest is not None:
700            self.write_c(' >> ')
701            self.visit(node.dest)
702            want_comma = True
703        for value in node.values:
704            if want_comma:
705                self.write_c(', ')
706            self.visit(value)
707            want_comma = True
708        if not node.nl:
709            self.write_c(',')
710
711    def visit_Delete(self, node):
712        self.unsupported(node)
713        self.newline(node)
714        self.write_python('del ')
715        for idx, target in enumerate(node):
716            if idx:
717                self.write_python(', ')
718            self.visit(target)
719
720    def visit_TryExcept(self, node):
721        self.unsupported(node)
722        self.newline(node)
723        self.write_python('try:')
724        self.body(node.body)
725        for handler in node.handlers:
726            self.visit(handler)
727
728    def visit_TryFinally(self, node):
729        self.unsupported(node)
730        self.newline(node)
731        self.write_python('try:')
732        self.body(node.body)
733        self.newline(node)
734        self.write_python('finally:')
735        self.body(node.finalbody)
736
737    def visit_Global(self, node):
738        self.unsupported(node)
739        self.newline(node)
740        self.write_python('global ' + ', '.join(node.names))
741
742    def visit_Nonlocal(self, node):
743        self.newline(node)
744        self.write_python('nonlocal ' + ', '.join(node.names))
745
746    def visit_Return(self, node):
747        self.add_current_line()
748        if node.value is None:
749            self.write_c('return')
750        else:
751            self.write_c('return(')
752            self.visit(node.value)
753        self.write_c(')')
754        self.add_semi_colon()
755        self.add_c_line(self.current_statement)
756        self.current_statement = ''
757
758    def visit_Break(self, node):
759        self.newline(node)
760        self.write_c('break')
761
762    def visit_Continue(self, node):
763        self.newline(node)
764        self.write_c('continue')
765
766    def visit_Raise(self, node):
767        self.unsupported(node)
768        # CRUFT: Python 2.6 / 3.0 compatibility
769        self.newline(node)
770        self.write_python('raise')
771        if hasattr(node, 'exc') and node.exc is not None:
772            self.write_python(' ')
773            self.visit(node.exc)
774            if node.cause is not None:
775                self.write_python(' from ')
776                self.visit(node.cause)
777        elif hasattr(node, 'type') and node.type is not None:
778            self.visit(node.type)
779            if node.inst is not None:
780                self.write_python(', ')
781                self.visit(node.inst)
782            if node.tback is not None:
783                self.write_python(', ')
784                self.visit(node.tback)
785
786    # Expressions
787
788    def visit_Attribute(self, node):
789        self.unsupported(node, "attribute reference a.b not supported")
790        self.visit(node.value)
791        self.write_python('.' + node.attr)
792
793    def visit_Call(self, node):
794        want_comma = []
795        def write_comma():
796            if want_comma:
797                self.write_c(', ')
798            else:
799                want_comma.append(True)
800        if hasattr(node.func, 'id'):
801            if node.func.id not in self.c_functions:
802                self.c_functions.append(node.func.id)
803            if node.func.id == 'abs':
804                self.write_c("fabs ")
805            elif node.func.id == 'int':
806                self.write_c('(int) ')
807            elif node.func.id == "SINCOS":
808                self.write_sincos(node)
809                return
810            elif node.func.id == "print":
811                self.write_print(node)
812                return
813            else:
814                self.visit(node.func)
815        else:
816            self.visit(node.func)
817        self.write_c('(')
818        for arg in node.args:
819            write_comma()
820            self.visited_args = True
821            self.visit(arg)
822        for keyword in node.keywords:
823            write_comma()
824            self.write_c(keyword.arg + '=')
825            self.visit(keyword.value)
826        if hasattr(node, 'starargs'):
827            if node.starargs is not None:
828                write_comma()
829                self.write_c('*')
830                self.visit(node.starargs)
831        if hasattr(node, 'kwargs'):
832            if node.kwargs is not None:
833                write_comma()
834                self.write_c('**')
835                self.visit(node.kwargs)
836        self.write_c(')')
837        if (self.inside_if == False):
838            self.write_c(';')
839
840    def visit_Name(self, node):
841        self.write_c(node.id)
842        if node.id in self.c_pointers and not self.in_subref:
843            self.write_c("[0]")
844        name = ""
845        sub = node.id.find("[")
846        if sub > 0:
847            name = node.id[0:sub].strip()
848        else:
849            name = node.id
850        # add variable to c_vars if it ins't there yet, not an argument and not a number
851        if (name not in self.c_functions and name not in self.c_vars and
852                name not in self.c_int_vars and name not in self.arguments and
853                name not in self.c_constants and not name.isdigit()):
854            if self.in_subscript:
855                self.add_c_int_var (self, node.id)
856            else:
857                self.c_vars.append(node.id)
858
859    def visit_Str(self, node):
860        self.write_c(repr(node.s))
861
862    def visit_Bytes(self, node):
863        self.write_c(repr(node.s))
864
865    def visit_Num(self, node):
866        self.write_c(repr(node.n))
867
868    def visit_Tuple(self, node):
869        for idx, item in enumerate(node.elts):
870            if idx:
871                self.tuples.append(item)
872            else:
873                self.visit(item)
874
875    def sequence_visit(left, right):
876        def visit(self, node):
877            self.is_sequence = True
878            c_vec = C_Vector()
879            s = ""
880            for idx, item in enumerate(node.elts):
881                if idx > 0 and s:
882                    s += ', '
883                if hasattr(item, 'id'):
884                    c_vec.values.append(item.id)
885#                    s += item.id
886                elif hasattr(item, 'n'):
887                    c_vec.values.append(item.n)
888#                    s += str(item.n)
889                else:
890                    temp = self.current_statement
891                    self.current_statement = ""
892                    self.visit(item)
893                    c_vec.values.append(self.current_statement)
894#                    s += self.current_statement
895                    self.current_statement = temp
896            if (self.assign_target):
897                self.add_c_int_var ('n')
898                c_vec.name = self.assign_target
899                self.C_Vectors.append(c_vec)
900                assign_strings = c_vec.get_assign()
901                for assign_str in assign_strings:
902                    self.assign_c_vector.append(assign_str)
903#                    self.add_c_line(assign_str)
904#            if s:
905#                self.c_vectors.append(s)
906#                vec_name = "vec"  + str(len(self.c_vectors))
907#                self.write_c(vec_name)
908        return visit
909
910    visit_List = sequence_visit('[', ']')
911    visit_Set = sequence_visit('{', '}')
912    del sequence_visit
913
914    def visit_Dict(self, node):
915        self.unsupported(node)
916        self.write_python('{')
917        for idx, (key, value) in enumerate(zip(node.keys, node.values)):
918            if idx:
919                self.write_python(', ')
920            self.visit(key)
921            self.write_python(': ')
922            self.visit(value)
923        self.write_python('}')
924
925    def get_special_power(self, string):
926        function_name = ''
927        is_negative_exp = False
928        if isevaluable(str(self.current_statement)):
929            exponent = eval(string)
930            is_negative_exp = exponent < 0
931            abs_exponent = abs(exponent)
932            if abs_exponent == 2:
933                function_name = "square"
934            elif abs_exponent == 3:
935                function_name = "cube"
936            elif abs_exponent == 0.5:
937                function_name = "sqrt"
938            elif abs_exponent == 1.0/3.0:
939                function_name = "cbrt"
940        if function_name == '':
941            function_name = "pow"
942        return function_name, is_negative_exp
943
944    def translate_power(self, node):
945        # get exponent by visiting the right hand argument.
946        function_name = "pow"
947        temp_statement = self.current_statement
948        # 'visit' functions write the results to the 'current_statement' class memnber
949        # Here, a temporary variable, 'temp_statement', is used, that enables the
950        # use of the 'visit' function
951        self.current_statement = ''
952        self.visit(node.right)
953        exponent = self.current_statement.replace(' ', '')
954        function_name, is_negative_exp = self.get_special_power(self.current_statement)
955        self.current_statement = temp_statement
956        if is_negative_exp:
957            self.write_c("1.0 /(")
958        self.write_c(function_name + "(")
959        self.visit(node.left)
960        if function_name == "pow":
961            self.write_c(", ")
962            self.visit(node.right)
963        self.write_c(")")
964        if is_negative_exp:
965            self.write_c(")")
966        self.write_c(" ")
967
968    def translate_integer_divide(self, node):
969        self.write_c("(int)(")
970        self.visit(node.left)
971        self.write_c(") /(int)(")
972        self.visit(node.right)
973        self.write_c(")")
974
975    def visit_BinOp(self, node):
976        self.write_c("(")
977        if '%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.Pow]:
978            self.translate_power(node)
979        elif '%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.FloorDiv]:
980            self.translate_integer_divide(node)
981        else:
982            self.visit(node.left)
983            self.write_c(' %s ' % BINOP_SYMBOLS[type(node.op)])
984            self.visit(node.right)
985        self.write_c(")")
986
987    # for C
988    def visit_BoolOp(self, node):
989        self.write_c('(')
990        for idx, value in enumerate(node.values):
991            if idx:
992                self.write_c(' %s ' % BOOLOP_SYMBOLS[type(node.op)])
993            self.visit(value)
994        self.write_c(')')
995
996    def visit_Compare(self, node):
997        self.write_c('(')
998        self.visit(node.left)
999        for op, right in zip(node.ops, node.comparators):
1000            self.write_c(' %s ' % CMPOP_SYMBOLS[type(op)])
1001            self.visit(right)
1002        self.write_c(')')
1003
1004    def visit_UnaryOp(self, node):
1005        self.write_c('(')
1006        op = UNARYOP_SYMBOLS[type(node.op)]
1007        self.write_c(op)
1008        if op == 'not':
1009            self.write_c(' ')
1010        self.visit(node.operand)
1011        self.write_c(')')
1012
1013    def visit_Subscript(self, node):
1014        if node.value.id not in self.c_constants:
1015            if node.value.id not in self.c_pointers:
1016                self.c_pointers.append(node.value.id)
1017        self.in_subref = True
1018        self.visit(node.value)
1019        self.in_subref = False
1020        self.write_c('[')
1021        self.in_subscript = True
1022        self.visit(node.slice)
1023        self.in_subscript = False
1024        self.write_c(']')
1025
1026    def visit_Slice(self, node):
1027        if node.lower is not None:
1028            self.visit(node.lower)
1029        self.write_python(':')
1030        if node.upper is not None:
1031            self.visit(node.upper)
1032        if node.step is not None:
1033            self.write_python(':')
1034            if not(isinstance(node.step, Name) and node.step.id == 'None'):
1035                self.visit(node.step)
1036
1037    def visit_ExtSlice(self, node):
1038        for idx, item in node.dims:
1039            if idx:
1040                self.write_python(', ')
1041            self.visit(item)
1042
1043    def visit_Yield(self, node):
1044        self.unsupported(node)
1045        self.write_python('yield ')
1046        self.visit(node.value)
1047
1048    def visit_Lambda(self, node):
1049        self.unsupported(node)
1050        self.write_python('lambda ')
1051        self.visit(node.args)
1052        self.write_python(': ')
1053        self.visit(node.body)
1054
1055    def visit_Ellipsis(self, node):
1056        self.unsupported(node)
1057        self.write_python('Ellipsis')
1058
1059    def generator_visit(left, right):
1060        def visit(self, node):
1061            self.write_python(left)
1062            self.write_c(left)
1063            self.visit(node.elt)
1064            for comprehension in node.generators:
1065                self.visit(comprehension)
1066            self.write_c(right)
1067            #self.write_python(right)
1068        return visit
1069
1070    visit_ListComp = generator_visit('[', ']')
1071    visit_GeneratorExp = generator_visit('(', ')')
1072    visit_SetComp = generator_visit('{', '}')
1073    del generator_visit
1074
1075    def visit_DictComp(self, node):
1076        self.unsupported(node)
1077        self.write_python('{')
1078        self.visit(node.key)
1079        self.write_python(': ')
1080        self.visit(node.value)
1081        for comprehension in node.generators:
1082            self.visit(comprehension)
1083        self.write_python('}')
1084
1085    def visit_NameConstant (self, node):
1086        if (hasattr (node, "value")):
1087            if (node.value == True):
1088                val = "1"
1089            elif node.value == False:
1090                val = "0"
1091            else:
1092                val = ""
1093            if (len(val) > 0):
1094                self.write_c("(" + val + ")")
1095
1096    def visit_IfExp(self, node):
1097        self.write_c('(')
1098        self.visit(node.test)
1099        self.write_c(' ? ')
1100        self.visit(node.body)
1101        self.write_c(' : ')
1102        self.visit(node.orelse)
1103        self.write_c(');')
1104
1105    def visit_Starred(self, node):
1106        self.write_c('*')
1107        self.visit(node.value)
1108
1109    def visit_Repr(self, node):
1110        # CRUFT: python 2.6 only
1111        self.write_c('`')
1112        self.visit(node.value)
1113        self.write_python('`')
1114
1115    # Helper Nodes
1116
1117    def visit_alias(self, node):
1118        self.unsupported(node)
1119        self.write_python(node.name)
1120        if node.asname is not None:
1121            self.write_python(' as ' + node.asname)
1122
1123    def visit_comprehension(self, node):
1124        self.write_c(' for ')
1125        self.visit(node.target)
1126        self.write_C(' in ')
1127        #self.write_python(' in ')
1128        self.visit(node.iter)
1129        if node.ifs:
1130            for if_ in node.ifs:
1131                self.write_python(' if ')
1132                self.visit(if_)
1133
1134    def visit_arguments(self, node):
1135        self.signature(node)
1136
1137    def unsupported(self, node, message=None):
1138        if hasattr(node, "value"):
1139            lineno = node.value.lineno
1140        elif hasattr(node, "iter"):
1141            lineno = node.iter.lineno
1142        else:
1143            #print(dir(node))
1144            lineno = 0
1145
1146        lineno += self.lineno_offset
1147        if self.fname:
1148            location = "%s(%d)" % (self.fname, lineno)
1149        else:
1150            location = "%d" % (self.fname, lineno)
1151        if self.current_function:
1152            location += ", function %s" % self.current_function
1153        if message is None:
1154            message = node.__class__.__name__ + " syntax not supported"
1155        raise SyntaxError("[%s] %s" % (location, message))
1156
1157def print_function(f=None):
1158    """
1159    Print out the code for the function
1160    """
1161    # Include some comments to see if they get printed
1162    import ast
1163    import inspect
1164    if f is not None:
1165        tree = ast.parse(inspect.getsource(f))
1166        tree_source = to_source(tree)
1167        print(tree_source)
1168
1169def define_constant(name, value, block_size=1):
1170    # type: (str, any, int) -> str
1171    """
1172    Convert a python constant into a C constant of the same name.
1173
1174    Returns the C declaration of the constant as a string, possibly containing
1175    line feeds.  The string will not be indented.
1176
1177    Supports int, double and sequences of double.
1178    """
1179    const = "constant "  # OpenCL needs globals to be constant
1180    if isinstance(value, int):
1181        parts = [const + "int ", name, " = ", "%d"%value, ";"]
1182    elif isinstance(value, float):
1183        parts = [const + "double ", name, " = ", "%.15g"%value, ";"]
1184    else:
1185        try:
1186            len(value)
1187        except TypeError:
1188            raise TypeError("constant %s must be int, float or [float, ...]"%name)
1189        # extend constant arrays to a multiple of 4; not sure if this
1190        # is necessary, but some OpenCL targets broke if the number
1191        # of parameters in the parameter table was not a multiple of 4,
1192        # so do it for all constant arrays to be safe.
1193        if len(value)%block_size != 0:
1194            value = list(value) + [0.]*(block_size - len(value)%block_size)
1195        elements = ["%.15g"%v for v in value]
1196        parts = [const + "double ", name, "[]", " = ",
1197                 "{\n   ", ", ".join(elements), "\n};"]
1198
1199    return "".join(parts)
1200
1201
1202# Modified from the following:
1203#
1204#    http://code.activestate.com/recipes/578272-topological-sort/
1205#    Copyright (C) 2012 Sam Denton
1206#    License: MIT
1207def ordered_dag(dag):
1208    # type: (Dict[T, Set[T]]) -> Iterator[T]
1209    """
1210    Given a dag defined by a dictionary of {k1: [k2, ...]} yield keys
1211    in order such that every key occurs after the keys it depends upon.
1212
1213    This is an iterator not a sequence.  To reverse it use::
1214
1215        reversed(tuple(ordered_dag(dag)))
1216
1217    Raise an error if there are any cycles.
1218
1219    Keys are arbitrary hashable values.
1220    """
1221    # Local import to make the function stand-alone, and easier to borrow
1222    from functools import reduce
1223
1224    dag = dag.copy()
1225
1226    # make leaves depend on the empty set
1227    leaves = reduce(set.union, dag.values()) - set(dag.keys())
1228    dag.update({node: set() for node in leaves})
1229    while True:
1230        leaves = set(node for node, links in dag.items() if not links)
1231        if not leaves:
1232            break
1233        for node in leaves:
1234            yield node
1235        dag = {node: (links-leaves)
1236               for node, links in dag.items() if node not in leaves}
1237    if dag:
1238        raise ValueError("Cyclic dependes exists amongst these items:\n%s"
1239                         % ", ".join(str(node) for node in dag.keys()))
1240
1241def translate(functions, constants=None):
1242    # type: (List[(str, str, int)], Dict[str, any]) -> List[str]
1243    """
1244    Convert a set of functions
1245    """
1246    snippets = []
1247    snippets.append("#include <math.h>")
1248    snippets.append("")
1249    warnings = []
1250    for source, fname, lineno in functions:
1251        line_directive = '#line %d "%s"'%(lineno, fname.replace('\\', '\\\\'))
1252        snippets.append(line_directive)
1253        tree = ast.parse(source)
1254        c_code,w = to_source(tree, constants=constants, fname=fname, lineno=lineno)
1255        if w:
1256            warnings.append ("\n".join(w))
1257        snippets.append(c_code)
1258    return "\n".join(snippets),warnings
1259#    return snippets
1260
1261def print_usage ():
1262        print("""\
1263            Usage: python py2c.py <infile> [<outfile>]
1264            if outfile is omitted, output file is '<infile>.c'
1265        """)
1266
1267def get_files_names ():
1268    import os
1269    fname_in = fname_out = ""
1270    valid_params =  len(sys.argv) > 1
1271    if (valid_params == False):
1272        print_usage()
1273    else:
1274        fname_in = sys.argv[1]
1275        if len(sys.argv) == 2:
1276            fname_base = os.path.splitext(fname_in)[0]
1277            fname_out = str(fname_base) + '.c'
1278        else:
1279            fname_out = sys.argv[2]
1280    return valid_params, fname_in,fname_out
1281
1282def main():
1283    print("Parsing...using Python" + sys.version)
1284    valid_params, fname_in, fname_out = get_files_names ()
1285    if (valid_params == False):
1286        print("Input parameters error.\nExiting")
1287        return
1288
1289    with open(fname_in, "r") as python_file:
1290            code = python_file.read()
1291    name = "gauss"
1292    code = (code
1293            .replace(name+'.n', 'GAUSS_N')
1294            .replace(name+'.z', 'GAUSS_Z')
1295            .replace(name+'.w', 'GAUSS_W'))
1296
1297    translation,warnings = translate([(code, fname_in, 1)])
1298
1299    with open(fname_out, "w") as file_out:
1300        file_out.write(str(translation))
1301    if warnings:
1302        print("".join(str(warnings[0])))
1303    print("...Done")
1304
1305if __name__ == "__main__":
1306    try:
1307        main()
1308    except Exception as excp:
1309        print ("Error:\n" + str(excp.args))
Note: See TracBrowser for help on using the repository browser.