source: sasmodels/sasmodels/py2c.py @ 937afef

Last change on this file since 937afef was 937afef, checked in by Omer Eisenberg <omereis@…>, 6 years ago

fixed bug in 'for' loop translation

  • Property mode set to 100644
File size: 38.1 KB
Line 
1
2"""
3    codegen
4    ~~~~~~~
5
6    Extension to ast that allow ast -> python code generation.
7
8    :copyright: Copyright 2008 by Armin Ronacher.
9    :license: BSD.
10"""
11"""
12    Variables definition in C
13    -------------------------
14    Defining variables within the Translate function is a bit of a guess work, using following rules.
15    *   By default, a variable is a 'double'.
16    *   Variable in a for loop is an int.
17    *   Variable that is references with brackets is an array of doubles. The variable within the brackets
18        is integer. For example, in the reference 'var1[var2]', var1 is a double array, and var2 is an integer.
19    *   Assignment to an argument makes that argument an array, and the index in that assignment is 0.
20        For example, the following python code
21            def func (arg1, arg2):
22                arg2 = 17.
23        is translated to the following C code
24            double func (double arg1)
25            {
26                arg2[0] = 17.0;
27            }
28        For example, the following python code is translated to the following C code
29            def func (arg1, arg2):          double func (double arg1) {
30                arg2 = 17.                      arg2[0] = 17.0;
31                                            }
32    *   All functions are defined as double, even if there is no return statement.
33
34
35Update Notes
36============
3711/22 14:15, O.E.   Each 'visit_*' method is to build a C statement string. It shold insert 4 blanks per indentation level.
38                    The 'body' method will combine all the strings, by adding the 'current_statement' to the c_proc string list
39   11/2017, OE: variables, argument definition implemented.
40   Note: An argument is considered an array if it is the target of an assignment. In that case it is translated to <var>[0]
4111/27/2017, OE: 'pow' basicly working
42  /12/2017, OE: Multiple assignment: a1,a2,...,an=b1,b2,...bn implemented
43  /12/2017, OE: Power function, including special cases of square(x) (pow(x,2)) and cube(x) (pow(x,3)), implemented in translate_power,
44                called from visit_BinOp
4512/07/2017, OE: Translation of integer division, '\\' in python, implemented in translate_integer_divide, called from visit_BinOp
4612/07/2017, OE: C variable definition handled in 'define_C_Vars'
47              : Python integer division, '//', translated to C in 'translate_integer_divide'
48"""
49import ast
50import sys
51from ast import NodeVisitor
52
53BINOP_SYMBOLS = {}
54BINOP_SYMBOLS[ast.Add] = '+'
55BINOP_SYMBOLS[ast.Sub] = '-'
56BINOP_SYMBOLS[ast.Mult] = '*'
57BINOP_SYMBOLS[ast.Div] = '/'
58BINOP_SYMBOLS[ast.Mod] = '%'
59BINOP_SYMBOLS[ast.Pow] = '**'
60BINOP_SYMBOLS[ast.LShift] = '<<'
61BINOP_SYMBOLS[ast.RShift] = '>>'
62BINOP_SYMBOLS[ast.BitOr] = '|'
63BINOP_SYMBOLS[ast.BitXor] = '^'
64BINOP_SYMBOLS[ast.BitAnd] = '&'
65BINOP_SYMBOLS[ast.FloorDiv] = '//'
66
67BOOLOP_SYMBOLS = {}
68BOOLOP_SYMBOLS[ast.And] = '&&'
69BOOLOP_SYMBOLS[ast.Or]  = '||'
70
71CMPOP_SYMBOLS = {}
72CMPOP_SYMBOLS[ast.Eq] = '=='
73CMPOP_SYMBOLS[ast.NotEq] = '!='
74CMPOP_SYMBOLS[ast.Lt] = '<'
75CMPOP_SYMBOLS[ast.LtE] = '<='
76CMPOP_SYMBOLS[ast.Gt] = '>'
77CMPOP_SYMBOLS[ast.GtE] = '>='
78CMPOP_SYMBOLS[ast.Is] = 'is'
79CMPOP_SYMBOLS[ast.IsNot] = 'is not'
80CMPOP_SYMBOLS[ast.In] = 'in'
81CMPOP_SYMBOLS[ast.NotIn] = 'not in'
82
83UNARYOP_SYMBOLS = {}
84UNARYOP_SYMBOLS[ast.Invert] = '~'
85UNARYOP_SYMBOLS[ast.Not] = 'not'
86UNARYOP_SYMBOLS[ast.UAdd] = '+'
87UNARYOP_SYMBOLS[ast.USub] = '-'
88
89
90#def to_source(node, indent_with=' ' * 4, add_line_information=False):
91def to_source(node, func_name):
92    """This function can convert a node tree back into python sourcecode.
93    This is useful for debugging purposes, especially if you're dealing with
94    custom asts not generated by python itself.
95
96    It could be that the sourcecode is evaluable when the AST itself is not
97    compilable / evaluable.  The reason for this is that the AST contains some
98    more data than regular sourcecode does, which is dropped during
99    conversion.
100
101    Each level of indentation is replaced with `indent_with`.  Per default this
102    parameter is equal to four spaces as suggested by PEP 8, but it might be
103    adjusted to match the application's styleguide.
104
105    If `add_line_information` is set to `True` comments for the line numbers
106    of the nodes are added to the output.  This can be used to spot wrong line
107    number information of statement nodes.
108    """
109    generator = SourceGenerator(' ' * 4, False)
110#    generator.required_functions = func_name
111    generator.visit(node)
112
113#    return ''.join(generator.result)
114    return ''.join(generator.c_proc)
115
116def isevaluable(s):
117    try:
118        eval(s)
119        return True
120    except:
121        return False
122   
123class SourceGenerator(NodeVisitor):
124    """This visitor is able to transform a well formed syntax tree into python
125    sourcecode.  For more details have a look at the docstring of the
126    `node_to_source` function.
127    """
128
129    def __init__(self, indent_with, add_line_information=False):
130        self.result = []
131        self.indent_with = indent_with
132        self.add_line_information = add_line_information
133        self.indentation = 0
134        self.new_lines = 0
135        self.c_proc = []
136# for C
137        self.signature_line = 0
138        self.arguments = []
139        self.name = ""
140        self.warnings = []
141        self.Statements = []
142        self.current_statement = ""
143        self.strMethodSignature = ""
144        self.C_Vars = []
145        self.C_IntVars = []
146        self.MathIncludeed = False
147        self.C_Pointers = []
148        self.C_DclPointers = []
149        self.C_Functions = []
150        self.C_Vectors = []
151        self.SubRef = False
152        self.InSubscript = False
153        self.Tuples = []
154        self.required_functions = []
155        self.is_sequence = False
156        self.visited_args = False
157
158    def write_python(self, x):
159        if self.new_lines:
160            if self.result:
161                self.result.append('\n' * self.new_lines)
162            self.result.append(self.indent_with * self.indentation)
163            self.new_lines = 0
164        self.result.append(x)
165
166    def write_c(self, x):
167        self.current_statement += x
168       
169    def add_c_line(self, x):
170        string = ''
171        for i in range (self.indentation):
172            string += ("    ")
173        string += str(x)
174        self.c_proc.append (str(string + "\n"))
175        x = ''
176
177    def add_current_line (self):
178        if (len(self.current_statement) > 0):
179            self.add_c_line (self.current_statement)
180            self.current_statement = ''
181
182    def AddUniqueVar (self, new_var):
183        if ((new_var not in self.C_Vars)):
184            self.C_Vars.append (str (new_var))
185
186    def WriteSincos (self, node):
187        angle = str(node.args[0].id)
188        self.write_c (node.args[1].id + " = sin (" + angle + ");")
189        self.add_current_line()
190        self.write_c (node.args[2].id + " = cos (" + angle + ");")
191        self.add_current_line()
192        for arg in node.args:
193            self.AddUniqueVar (arg.id)
194
195    def newline(self, node=None, extra=0):
196        self.new_lines = max(self.new_lines, 1 + extra)
197        if node is not None and self.add_line_information:
198            self.write_c('# line: %s' % node.lineno)
199            self.new_lines = 1
200        if (len(self.current_statement)):
201            self.Statements.append (self.current_statement)
202            self.current_statement = ''
203
204    def body(self, statements):
205        if (len(self.current_statement)):
206            self.add_current_line ()
207        self.new_line = True
208        self.indentation += 1
209        for stmt in statements:
210            target_name = ''
211            if (hasattr (stmt, 'targets')):
212                if (hasattr(stmt.targets[0],'id')):
213                    target_name = stmt.targets[0].id # target name needed for debug only
214            self.visit(stmt)
215        self.add_current_line () # just for breaking point. to be deleted.
216        self.indentation -= 1
217
218    def body_or_else(self, node):
219        self.body(node.body)
220        if node.orelse:
221            self.newline()
222            self.write_c('else:')
223            self.body(node.orelse)
224
225    def signature(self, node):
226        want_comma = []
227        def write_comma():
228            if want_comma:
229                self.write_c(', ')
230            else:
231                want_comma.append(True)
232# for C
233        for arg in node.args:
234            self.arguments.append (arg.arg)
235
236        padding = [None] * (len(node.args) - len(node.defaults))
237        for arg, default in zip(node.args, padding + node.defaults):
238            if default is not None:
239                self.warnings.append ("Default Parameter unknown to C")
240                w_str = "Default Parameters are unknown to C: '" + arg.arg + " = " + str(default.n) + "'"
241                self.warnings.append (w_str)
242#                self.write_python('=')
243#                self.visit(default)
244
245    def decorators(self, node):
246        for decorator in node.decorator_list:
247            self.newline(decorator)
248            self.write_python('@')
249            self.visit(decorator)
250
251    # Statements
252
253    def visit_Assert(self, node):
254        self.newline(node)
255        self.write_c('assert ')
256        self.visit(node.test)
257        if node.msg is not None:
258           self.write_python(', ')
259           self.visit(node.msg)
260
261    def define_C_Vars (self, target):
262        if (hasattr (target, 'id')):
263# a variable is considered an array if it apears in the agrument list
264# and being assigned to. For example, the variable p in the following
265# sniplet is a pointer, while q is not
266# def somefunc (p, q):
267#  p = q + 1
268#  return
269#
270            if (target.id not in self.C_Vars):
271                if (target.id in self.arguments):
272                    idx = self.arguments.index (target.id)
273                    new_target = self.arguments[idx] + "[0]"
274                    if (new_target not in self.C_Pointers):
275                        target.id = new_target
276                        self.C_Pointers.append (self.arguments[idx])
277                else:
278                    self.C_Vars.append (target.id)
279
280    def add_semi_colon (self):
281        semi_pos = self.current_statement.find(';')
282        if (semi_pos < 0):
283            self.write_c(';')
284
285    def visit_Assign(self, node):
286        self.add_current_line ()
287        for idx, target in enumerate(node.targets): # multi assign, as in 'a = b = c = 7'
288            if idx:
289                self.write_c(' = ')
290            self.define_C_Vars (target)
291            self.visit(target)
292        if (len(self.Tuples) > 0):
293            tplTargets = list (self.Tuples)
294            self.Tuples.clear()
295        self.write_c(' = ')
296        self.is_sequence = False
297        self.visited_args = False
298        self.visit(node.value)
299        self.add_semi_colon ()
300#        self.write_c(';')
301        self.add_current_line ()
302        for n,item in enumerate (self.Tuples):
303            self.visit(tplTargets[n])
304            self.write_c(' = ')
305            self.visit(item)
306            self.add_semi_colon ()
307            self.add_current_line ()
308        if ((self.is_sequence) and (not self.visited_args)):
309            for target in node.targets:
310                if (hasattr (target, 'id')):
311                    if ((target.id in self.C_Vars) and (target.id not in self.C_DclPointers)):
312                        if (target.id not in self.C_DclPointers):
313                            self.C_DclPointers.append (target.id)
314                            if (target.id in self.C_Vars):
315                                self.C_Vars.remove(target.id)
316        self.current_statement = ''
317
318    def visit_AugAssign(self, node):
319        if (node.target.id not in self.C_Vars):
320            if (node.target.id not in self.arguments):
321                self.C_Vars.append (node.target.id)
322        self.visit(node.target)
323        self.write_c(' ' + BINOP_SYMBOLS[type(node.op)] + '= ')
324        self.visit(node.value)
325        self.add_semi_colon ()
326#        self.write_c(';')
327        self.add_current_line ()
328
329    def visit_ImportFrom(self, node):
330        self.newline(node)
331        self.write_python('from %s%s import ' % ('.' * node.level, node.module))
332        for idx, item in enumerate(node.names):
333            if idx:
334                self.write_python(', ')
335            self.write_python(item)
336
337    def visit_Import(self, node):
338        self.newline(node)
339        for item in node.names:
340            self.write_python('import ')
341            self.visit(item)
342
343    def visit_Expr(self, node):
344        self.newline(node)
345        self.generic_visit(node)
346
347    def listToDeclare(self, Vars):
348        s = ''
349        if (len(Vars) > 0):
350            s = ",".join (Vars)
351        return (s)
352
353    def write_C_Pointers (self, start_var):
354        if (len (self.C_DclPointers) > 0):
355            vars = ""
356            for c_ptr in self.C_DclPointers:
357                if (len(vars) > 0):
358                    vars += ", "
359                if (c_ptr not in self.arguments):
360                    vars += "*" + c_ptr
361                if (c_ptr in self.C_Vars):
362                    if (c_ptr in self.C_Vars):
363                        self.C_Vars.remove (c_ptr)
364            if (len(vars) > 0):
365                c_dcl = "    double " + vars + ";"
366                self.c_proc.insert (start_var, c_dcl + "\n")
367                start_var += 1
368        return start_var
369
370    def insert_C_Vars (self, start_var):
371        fLine = False
372        start_var = self.write_C_Pointers (start_var)
373        if (len(self.C_IntVars) > 0):
374            for var in self.C_IntVars:
375                if (var in self.C_Vars):
376                    self.C_Vars.remove(var)
377            s = self.listToDeclare(self.C_IntVars)
378            self.c_proc.insert (start_var, "    int " + s + ";\n")
379            fLine = True
380            start_var += 1
381           
382        if (len(self.C_Vars) > 0):
383            s = self.listToDeclare(self.C_Vars)
384            self.c_proc.insert (start_var, "    double " + s + ";\n")
385            fLine = True
386            start_var += 1
387#        if (len(self.C_IntVars) > 0):
388#            s = self.listToDeclare(self.C_IntVars)
389#            self.c_proc.insert (start_var, "    int " + s + ";\n")
390#            fLine = True
391#            start_var += 1
392        if (len (self.C_Vectors) > 0):
393            s = self.listToDeclare(self.C_Vectors)
394            for n in range(len(self.C_Vectors)):
395                name = "vec" + str(n+1)
396                c_dcl = "    double " + name + "[] = {" + self.C_Vectors[n] + "};"
397                self.c_proc.insert (start_var, c_dcl + "\n")
398                start_var += 1
399        self.C_Vars.clear()
400        self.C_IntVars.clear()
401        self.C_Vectors.clear()
402        self.C_Pointers.clear()
403        self.C_DclPointers
404        if (fLine == True):
405            self.c_proc.insert (start_var, "\n")
406        return
407        s = ''
408        for n in range(len(self.C_Vars)):
409            s += str(self.C_Vars[n])
410            if n < len(self.C_Vars) - 1:
411                s += ", "
412        if (len(s) > 0):
413            self.c_proc.insert (start_var, "    double " + s + ";\n")
414            self.c_proc.insert (start_var + 1, "\n")
415
416    def writeInclude (self):
417        if (self.MathIncludeed == False):
418            self.add_c_line("#include <math.h>\n")
419            self.add_c_line("static double pi = 3.14159265359;\n")
420            self.MathIncludeed = True
421
422    def ListToString (self, strings):
423        s = ''
424        for n in range(len(strings)):
425            s += strings[n]
426            if (n < (len(strings) - 1)):
427                s += ", "
428        return (s)
429
430    def getMethodSignature (self):
431#        args_str = ListToString (self.arguments)
432        args_str = ''
433        for n in range(len(self.arguments)):
434            args_str += "double " + self.arguments[n]
435            if (n < (len(self.arguments) - 1)):
436                args_str += ", "
437        return (args_str)
438#        self.strMethodSignature = 'double ' + self.name + ' (' + args_str + ")"
439
440    def InsertSignature (self):
441        args_str = ''
442        for n in range(len(self.arguments)):
443            args_str += "double " + self.arguments[n]
444            if (self.arguments[n] in self.C_Pointers):
445                args_str += "[]"
446            if (n < (len(self.arguments) - 1)):
447                args_str += ", "
448        self.strMethodSignature = 'double ' + self.name + ' (' + args_str + ")"
449        if (self.signature_line >= 0):
450            self.c_proc.insert (self.signature_line, self.strMethodSignature)
451
452    def visit_FunctionDef(self, node):
453        self.newline(extra=1)
454        self.decorators(node)
455        self.newline(node)
456        self.arguments = []
457        self.name = node.name
458#        if self.name not in self.required_functions[0]:
459#           return
460        print("Parsing '" + self.name + "'")
461        args_str = ""
462
463        self.visit(node.args)
464# for C
465        self.writeInclude()
466        self.getMethodSignature ()
467# for C
468        self.signature_line = len(self.c_proc)
469#        self.add_c_line(self.strMethodSignature)
470        self.add_c_line("\n{")
471        start_vars = len(self.c_proc) + 1
472        self.body(node.body)
473        self.add_c_line("}\n")
474        self.InsertSignature ()
475        self.insert_C_Vars (start_vars)
476        self.C_Pointers = []
477
478    def visit_ClassDef(self, node):
479        have_args = []
480        def paren_or_comma():
481            if have_args:
482                self.write_python(', ')
483            else:
484                have_args.append(True)
485                self.write_python('(')
486
487        self.newline(extra=2)
488        self.decorators(node)
489        self.newline(node)
490        self.write_python('class %s' % node.name)
491        for base in node.bases:
492            paren_or_comma()
493            self.visit(base)
494        # XXX: the if here is used to keep this module compatible
495        #      with python 2.6.
496        if hasattr(node, 'keywords'):
497            for keyword in node.keywords:
498                paren_or_comma()
499                self.write_python(keyword.arg + '=')
500                self.visit(keyword.value)
501            if node.starargs is not None:
502                paren_or_comma()
503                self.write_python('*')
504                self.visit(node.starargs)
505            if node.kwargs is not None:
506                paren_or_comma()
507                self.write_python('**')
508                self.visit(node.kwargs)
509        self.write_python(have_args and '):' or ':')
510        self.body(node.body)
511
512    def visit_If(self, node):
513        self.write_c('if ')
514        self.visit(node.test)
515        self.write_c(' {')
516        self.body(node.body)
517        self.add_c_line('}')
518        while True:
519            else_ = node.orelse
520            if len(else_) == 0:
521                break
522#            elif hasattr (else_, 'orelse'):
523            elif len(else_) == 1 and isinstance(else_[0], ast.If):
524                node = else_[0]
525#                self.newline()
526                self.write_c('else if ')
527                self.visit(node.test)
528                self.write_c(' {')
529                self.body(node.body)
530                self.add_current_line ()
531                self.add_c_line('}')
532#                break
533            else:
534                self.newline()
535                self.write_c('else {')
536                self.body(node.body)
537                self.add_c_line('}')
538                break
539
540    def getNodeLineNo (self, node):
541        line_number = -1
542        if (hasattr (node,'value')):
543            line_number = node.value.lineno
544        elif hasattr(node,'iter'):
545            if hasattr(node.iter,'lineno'):
546                line_number = node.iter.lineno
547        return (line_number)
548
549    def GetNodeAsString (self, node):
550        res = ''
551        if (hasattr(node,'n')):
552            res = str(node.n)
553        elif (hasattr(node,'id')):
554            res = node.id
555        return (res)
556
557    def GetForRange(self, node):
558        stop = ""
559        start = '0'
560        step = '1'
561        for_args = []
562        temp_statement = self.current_statement
563        self.current_statement = ''
564        for arg in node.iter.args:
565            self.visit(arg)
566            for_args.append(self.current_statement)
567            self.current_statement = ''
568        self.current_statement = temp_statement
569        if (len(for_args) == 1):
570            stop = for_args[0]
571        elif (len(for_args) == 2):
572            start = for_args[0]
573            stop = for_args[1]
574        elif (len(for_args) == 3):
575            start = for_args[0]
576            stop = for_args[1]
577            start = for_args[2]
578        else:
579            raise("Ilegal for loop parameters")
580        return (start, stop, step)
581
582    def visit_For(self, node):
583# node: for iterator is stored in node.target.
584# Iterator name is in node.target.id.
585        self.add_current_line()
586#        if (len(self.current_statement) > 0):
587#            self.add_c_line(self.current_statement)
588#            self.current_statement = ''
589        fForDone = False
590        self.current_statement = ''
591        if (hasattr(node.iter,'func')):
592            if (hasattr (node.iter.func,'id')):
593                if (node.iter.func.id == 'range'):
594                    self.visit(node.target)
595                    iterator = self.current_statement
596                    self.current_statement = ''
597                    if (iterator not in self.C_IntVars):
598                        self.C_IntVars.append (iterator)
599                    start, stop, step = self.GetForRange(node)
600                    self.write_c ("for (" + iterator + "=" + str(start) + \
601                                  " ; " + iterator + " < " + str(stop) + \
602                                  " ; " + iterator + " += " + str(step) + ") {")
603                    self.body_or_else(node)
604                    self.write_c ("}")
605                    fForDone = True
606        if (fForDone == False):
607            line_number = self.getNodeLineNo (node)
608            self.current_statement = ''
609            self.write_c('for ')
610            self.visit(node.target)
611            self.write_c(' in ')
612            self.visit(node.iter)
613            self.write_c(':')
614            errStr = "Conversion Error in function " + self.name + ", Line #" + str (line_number)
615            errStr += "\nPython for expression not supported: '" + self.current_statement + "'"
616            raise Exception(errStr)
617
618    def visit_While(self, node):
619        self.newline(node)
620        self.write_c('while ')
621        self.visit(node.test)
622        self.write_c(':')
623        self.body_or_else(node)
624
625    def visit_With(self, node):
626        self.newline(node)
627        self.write_python('with ')
628        self.visit(node.context_expr)
629        if node.optional_vars is not None:
630            self.write_python(' as ')
631            self.visit(node.optional_vars)
632        self.write_python(':')
633        self.body(node.body)
634
635    def visit_Pass(self, node):
636        self.newline(node)
637        self.write_python('pass')
638
639    def visit_Print(self, node):
640# XXX: python 2.6 only
641        self.newline(node)
642        self.write_c('print ')
643        want_comma = False
644        if node.dest is not None:
645            self.write_c(' >> ')
646            self.visit(node.dest)
647            want_comma = True
648        for value in node.values:
649            if want_comma:
650                self.write_c(', ')
651            self.visit(value)
652            want_comma = True
653        if not node.nl:
654            self.write_c(',')
655
656    def visit_Delete(self, node):
657        self.newline(node)
658        self.write_python('del ')
659        for idx, target in enumerate(node):
660            if idx:
661                self.write_python(', ')
662            self.visit(target)
663
664    def visit_TryExcept(self, node):
665        self.newline(node)
666        self.write_python('try:')
667        self.body(node.body)
668        for handler in node.handlers:
669            self.visit(handler)
670
671    def visit_TryFinally(self, node):
672        self.newline(node)
673        self.write_python('try:')
674        self.body(node.body)
675        self.newline(node)
676        self.write_python('finally:')
677        self.body(node.finalbody)
678
679    def visit_Global(self, node):
680        self.newline(node)
681        self.write_python('global ' + ', '.join(node.names))
682
683    def visit_Nonlocal(self, node):
684        self.newline(node)
685        self.write_python('nonlocal ' + ', '.join(node.names))
686
687    def visit_Return(self, node):
688        self.newline(node)
689        if node.value is None:
690            self.write_c('return')
691        else:
692            self.write_c('return (')
693            self.visit(node.value)
694        self.write_c(');')
695#      self.add_current_statement (self)
696        self.add_c_line (self.current_statement)
697        self.current_statement = ''
698
699    def visit_Break(self, node):
700        self.newline(node)
701        self.write_c('break')
702
703    def visit_Continue(self, node):
704        self.newline(node)
705        self.write_c('continue')
706
707    def visit_Raise(self, node):
708        # XXX: Python 2.6 / 3.0 compatibility
709        self.newline(node)
710        self.write_python('raise')
711        if hasattr(node, 'exc') and node.exc is not None:
712            self.write_python(' ')
713            self.visit(node.exc)
714            if node.cause is not None:
715                self.write_python(' from ')
716                self.visit(node.cause)
717        elif hasattr(node, 'type') and node.type is not None:
718            self.visit(node.type)
719            if node.inst is not None:
720                self.write_python(', ')
721                self.visit(node.inst)
722            if node.tback is not None:
723                self.write_python(', ')
724                self.visit(node.tback)
725
726    # Expressions
727
728    def visit_Attribute(self, node):
729        errStr = "Conversion Error in function " + self.name + ", Line #" + str (node.value.lineno)
730        errStr += "\nPython expression not supported: '" + node.value.id + "." + node.attr + "'"
731        raise Exception(errStr)
732        self.visit(node.value)
733        self.write_python('.' + node.attr)
734
735    def visit_Call(self, node):
736        want_comma = []
737        def write_comma():
738            if want_comma:
739                self.write_c(', ')
740            else:
741                want_comma.append(True)
742        if (hasattr (node.func, 'id')):
743            if (node.func.id not in self.C_Functions):
744                self.C_Functions.append (node.func.id)
745            if (node.func.id == 'abs'):
746                self.write_c ("fabs ")
747            elif (node.func.id == 'int'):
748                self.write_c('(int) ')
749            elif (node.func.id == "SINCOS"):
750                self.WriteSincos (node)
751                return
752            else:
753                self.visit(node.func)
754        else:
755            self.visit(node.func)
756#self.C_Functions
757        self.write_c('(')
758        for arg in node.args:
759            write_comma()
760            self.visited_args = True 
761            self.visit(arg)
762        for keyword in node.keywords:
763            write_comma()
764            self.write_c(keyword.arg + '=')
765            self.visit(keyword.value)
766        if hasattr (node, 'starargs'):
767            if node.starargs is not None:
768                write_comma()
769                self.write_c('*')
770                self.visit(node.starargs)
771        if hasattr (node, 'kwargs'):
772            if node.kwargs is not None:
773                write_comma()
774                self.write_c('**')
775                self.visit(node.kwargs)
776        self.write_c(');')
777
778    def visit_Name(self, node):
779        self.write_c(node.id)
780        if ((node.id in self.C_Pointers) and (not self.SubRef)):
781            self.write_c("[0]")
782        name = ""
783        sub = node.id.find("[")
784        if (sub > 0):
785            name = node.id[0:sub].strip()
786        else:
787            name = node.id
788#       add variable to C_Vars if it ins't there yet, not an argument and not a number
789        if ((name not in self.C_Functions) and (name not in self.C_Vars) and (name not in self.C_IntVars) and (name not in self.arguments) and (name.isnumeric () == False)):
790            if (self.InSubscript):
791                self.C_IntVars.append (node.id)
792            else:
793                self.C_Vars.append (node.id)
794
795    def visit_Str(self, node):
796        self.write_c(repr(node.s))
797
798    def visit_Bytes(self, node):
799        self.write_c(repr(node.s))
800
801    def visit_Num(self, node):
802        self.write_c(repr(node.n))
803
804    def visit_Tuple(self, node):
805        for idx, item in enumerate(node.elts):
806            if idx:
807                self.Tuples.append(item)
808            else:
809                self.visit(item)
810
811    def sequence_visit(left, right):
812        def visit(self, node):
813            self.is_sequence = True
814            s = ""
815            for idx, item in enumerate(node.elts):
816                if ((idx > 0) and (len(s) > 0)):
817                    s += ', '
818                if (hasattr (item, 'id')):
819                    s += item.id
820                elif (hasattr(item,'n')):
821                    s += str(item.n)
822            if (len(s) > 0):
823                self.C_Vectors.append (s)
824                vec_name = "vec"  + str(len(self.C_Vectors))
825                self.write_c(vec_name)
826                vec_name += "#"
827        return visit
828
829    visit_List = sequence_visit('[', ']')
830    visit_Set = sequence_visit('{', '}')
831    del sequence_visit
832
833    def visit_Dict(self, node):
834        self.write_python('{')
835        for idx, (key, value) in enumerate(zip(node.keys, node.values)):
836            if idx:
837                self.write_python(', ')
838            self.visit(key)
839            self.write_python(': ')
840            self.visit(value)
841        self.write_python('}')
842
843    def get_special_power (self, string):
844        function_name = ''
845        is_negative_exp = False
846        if (isevaluable(str(self.current_statement))):
847            exponent = eval(string)
848            is_negative_exp = exponent < 0
849            abs_exponent = abs(exponent)
850            if (abs_exponent == 2):
851                function_name = "square"
852            elif (abs_exponent == 3):
853                function_name = "cube"
854            elif (abs_exponent == 0.5):
855                function_name = "sqrt"
856            elif (abs_exponent == 1.0/3.0):
857                function_name = "cbrt"
858        if (function_name == ''):
859            function_name = "pow"
860        return function_name, is_negative_exp
861
862    def translate_power (self, node):
863# get exponent by visiting the right hand argument.
864        function_name = "pow"
865        temp_statement = self.current_statement
866# 'visit' functions write the results to the 'current_statement' class memnber
867# Here, a temporary variable, 'temp_statement', is used, that enables the
868# use of the 'visit' function
869        self.current_statement = ''
870        self.visit(node.right)
871        exponent = self.current_statement.replace(' ','')
872        function_name, is_negative_exp = self.get_special_power (self.current_statement)
873        self.current_statement = temp_statement
874        if (is_negative_exp):
875            self.write_c ("1.0 / (")
876        self.write_c (function_name + " (")
877        self.visit(node.left)
878        if (function_name == "pow"):
879            self.write_c(", ")
880            self.visit(node.right)
881        self.write_c(")")
882        if (is_negative_exp):
883            self.write_c(")")
884        self.write_c(" ")
885
886    def translate_integer_divide (self, node):
887        self.write_c ("(int) (")
888        self.visit(node.left)
889        self.write_c (") / (int) (")
890        self.visit(node.right)
891        self.write_c (")")
892
893    def visit_BinOp(self, node):
894        if ('%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.Pow]):
895            self.translate_power (node)
896        elif ('%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.FloorDiv]):
897            self.translate_integer_divide (node)
898        else:
899            self.visit(node.left)
900            self.write_c(' %s ' % BINOP_SYMBOLS[type(node.op)])
901            self.visit(node.right)
902#       for C
903    def visit_BoolOp(self, node):
904        self.write_c('(')
905        for idx, value in enumerate(node.values):
906            if idx:
907                self.write_c(' %s ' % BOOLOP_SYMBOLS[type(node.op)])
908            self.visit(value)
909        self.write_c(')')
910
911    def visit_Compare(self, node):
912        self.write_c('(')
913        self.visit(node.left)
914        for op, right in zip(node.ops, node.comparators):
915            self.write_c(' %s ' % CMPOP_SYMBOLS[type(op)])
916            self.visit(right)
917        self.write_c(')')
918
919    def visit_UnaryOp(self, node):
920        self.write_c('(')
921        op = UNARYOP_SYMBOLS[type(node.op)]
922        self.write_c(op)
923        if op == 'not':
924            self.write_c(' ')
925        self.visit(node.operand)
926        self.write_c(')')
927
928    def visit_Subscript(self, node):
929        if (node.value.id not in self.C_Pointers):
930            self.C_Pointers.append (node.value.id)
931        self.SubRef = True
932        self.visit(node.value)
933        self.SubRef = False
934        self.write_c('[')
935        self.InSubscript = True
936        self.visit(node.slice)
937        self.InSubscript = False
938        self.write_c(']')
939
940    def visit_Slice(self, node):
941        if node.lower is not None:
942            self.visit(node.lower)
943        self.write_python(':')
944        if node.upper is not None:
945            self.visit(node.upper)
946        if node.step is not None:
947            self.write_python(':')
948            if not (isinstance(node.step, Name) and node.step.id == 'None'):
949                self.visit(node.step)
950
951    def visit_ExtSlice(self, node):
952        for idx, item in node.dims:
953            if idx:
954                self.write_python(', ')
955            self.visit(item)
956
957    def visit_Yield(self, node):
958        self.write_python('yield ')
959        self.visit(node.value)
960
961    def visit_Lambda(self, node):
962        self.write_python('lambda ')
963        self.visit(node.args)
964        self.write_python(': ')
965        self.visit(node.body)
966
967    def visit_Ellipsis(self, node):
968        self.write_python('Ellipsis')
969
970    def generator_visit(left, right):
971        def visit(self, node):
972            self.write_python(left)
973            self.write_c(left)
974            self.visit(node.elt)
975            for comprehension in node.generators:
976                self.visit(comprehension)
977            self.write_c(right)
978#            self.write_python(right)
979        return visit
980
981    visit_ListComp = generator_visit('[', ']')
982    visit_GeneratorExp = generator_visit('(', ')')
983    visit_SetComp = generator_visit('{', '}')
984    del generator_visit
985
986    def visit_DictComp(self, node):
987        self.write_python('{')
988        self.visit(node.key)
989        self.write_python(': ')
990        self.visit(node.value)
991        for comprehension in node.generators:
992            self.visit(comprehension)
993        self.write_python('}')
994
995    def visit_IfExp(self, node):
996        self.visit(node.body)
997        self.write_c(' if ')
998        self.visit(node.test)
999        self.write_c(' else ')
1000        self.visit(node.orelse)
1001
1002    def visit_Starred(self, node):
1003        self.write_c('*')
1004        self.visit(node.value)
1005
1006    def visit_Repr(self, node):
1007        # XXX: python 2.6 only
1008        self.write_c('`')
1009        self.visit(node.value)
1010        self.write_python('`')
1011
1012    # Helper Nodes
1013
1014    def visit_alias(self, node):
1015        self.write_python(node.name)
1016        if node.asname is not None:
1017            self.write_python(' as ' + node.asname)
1018
1019    def visit_comprehension(self, node):
1020        self.write_c(' for ')
1021        self.visit(node.target)
1022        self.write_C(' in ')
1023#        self.write_python(' in ')
1024        self.visit(node.iter)
1025        if node.ifs:
1026            for if_ in node.ifs:
1027                self.write_python(' if ')
1028                self.visit(if_)
1029
1030#    def visit_excepthandler(self, node):
1031#        self.newline(node)
1032#        self.write_python('except')
1033#        if node.type is not None:
1034#            self.write_python(' ')
1035#            self.visit(node.type)
1036#            if node.name is not None:
1037#                self.write_python(' as ')
1038#                self.visit(node.name)
1039#        self.body(node.body)
1040
1041    def visit_arguments(self, node):
1042        self.signature(node)
1043
1044def Iq1(q, porod_scale, porod_exp, lorentz_scale, lorentz_length, peak_pos, lorentz_exp=17):
1045    z1 = z2 = z = abs (q - peak_pos) * lorentz_length
1046    if (q > p):
1047        q = p + 17
1048        p = q - 5
1049    z3 = -8
1050    inten = (porod_scale / q ** porod_exp
1051                + lorentz_scale / (1 + z ** lorentz_exp))
1052    return inten
1053
1054def Iq(q, porod_scale, porod_exp, lorentz_scale, lorentz_length, peak_pos, lorentz_exp=17):
1055    z1 = z2 = z = abs (q - peak_pos) * lorentz_length
1056    if (q > p):
1057        q = p + 17
1058        p = q - 5
1059    elif (q == p):
1060        q = p * q
1061        q *= z1
1062        p = z1
1063    elif (q == 17):
1064        q = p * q - 17
1065    else:
1066        q += 7
1067    z3 = -8
1068    inten = (porod_scale / q ** porod_exp
1069                + lorentz_scale / (1 + z ** lorentz_exp))
1070    return inten
1071
1072def print_function(f=None):
1073    """
1074    Print out the code for the function
1075    """
1076    # Include some comments to see if they get printed
1077    import ast
1078    import inspect
1079    if f is not None:
1080        tree = ast.parse(inspect.getsource(f))
1081        tree_source = to_source (tree)
1082        print(tree_source)
1083
1084def translate (functions, constants=0):
1085    sniplets = []
1086    fname = functions[1]
1087    python_file = open (fname, "r")
1088    source = python_file.read()
1089    python_file.close()
1090    tree = ast.parse (source)
1091    sniplet = to_source (tree, functions) # in the future add filename, offset, constants
1092    sniplets.append(sniplet)
1093    return ("\n".join(sniplets))
1094
1095def get_file_names ():
1096    fname_in = ""
1097    fname_out = ""
1098    if (len(sys.argv) > 1):
1099        fname_in = sys.argv[1]
1100        fname_base = os.path.splitext (fname_in)
1101        if (len (sys.argv) == 2):
1102            fname_out = str (fname_base[0]) + '.c'
1103        else:
1104            fname_out = sys.argv[2]
1105        if (len (fname_in) > 0):
1106            python_file = open (sys.argv[1], "r")
1107            if (len (fname_out) > 0):
1108                file_out = open (fname_out, "w+")
1109    return len(sys.argv), fname_in, fname_out
1110
1111if __name__ == "__main__":
1112    import os
1113    print("Parsing...using Python" + sys.version)
1114    try:
1115        fname_in = ""
1116        fname_out = ""
1117        if (len (sys.argv) == 1):
1118            print ("Usage:\npython parse01.py <infile> [<outfile>] (if omitted, output file is '<infile>.c'")
1119        else:
1120            fname_in = sys.argv[1]
1121            fname_base = os.path.splitext (fname_in)
1122            if (len (sys.argv) == 2):
1123                fname_out = str (fname_base[0]) + '.c'
1124            else:
1125                fname_out = sys.argv[2]
1126            if (len (fname_in) > 0):
1127                python_file = open (sys.argv[1], "r")
1128                if (len (fname_out) > 0):
1129                    file_out = open (fname_out, "w+")
1130                functions = ["MultAsgn", "Iq41", "Iq2"]
1131                tpls = [functions, fname_in, 0]
1132                c_txt = translate (tpls)
1133                file_out.write (c_txt)
1134                file_out.close()
1135    except Exception as excp:
1136        print ("Error:\n" + str(excp.args))
1137    print("...Done")
Note: See TracBrowser for help on using the repository browser.