source: sasmodels/sasmodels/py2c.py @ 3f9db6e

Last change on this file since 3f9db6e was 3f9db6e, checked in by Omer Eisenberg <omereis@…>, 6 years ago

supporting all sorts of special power: ± 1/3, ½, 2, 3

  • Property mode set to 100644
File size: 36.1 KB
Line 
1
2"""
3    codegen
4    ~~~~~~~
5
6    Extension to ast that allow ast -> python code generation.
7
8    :copyright: Copyright 2008 by Armin Ronacher.
9    :license: BSD.
10"""
11"""
12    Variables definition in C
13    -------------------------
14    Defining variables within the Translate function is a bit of a guess work, using following rules.
15    *   By default, a variable is a 'double'.
16    *   Variable in a for loop is an int.
17    *   Variable that is references with brackets is an array of doubles. The variable within the brackets
18        is integer. For example, in the reference 'var1[var2]', var1 is a double array, and var2 is an integer.
19    *   Assignment to an argument makes that argument an array, and the index in that assignment is 0.
20        For example, the following python code
21            def func (arg1, arg2):
22                arg2 = 17.
23        is translated to the following C code
24            double func (double arg1)
25            {
26                arg2[0] = 17.0;
27            }
28        For example, the following python code is translated to the following C code
29            def func (arg1, arg2):          double func (double arg1) {
30                arg2 = 17.                      arg2[0] = 17.0;
31                                            }
32    *   All functions are defined as double, even if there is no return statement.
33
34
35Update Notes
36============
3711/22 14:15, O.E.   Each 'visit_*' method is to build a C statement string. It shold insert 4 blanks per indentation level.
38                    The 'body' method will combine all the strings, by adding the 'current_statement' to the c_proc string list
39   11/2017, OE: variables, argument definition implemented.
40   Note: An argument is considered an array if it is the target of an assignment. In that case it is translated to <var>[0]
4111/27/2017, OE: 'pow' basicly working
42  /12/2017, OE: Multiple assignment: a1,a2,...,an=b1,b2,...bn implemented
43  /12/2017, OE: Power function, including special cases of square(x) (pow(x,2)) and cube(x) (pow(x,3)), implemented in translate_power,
44                called from visit_BinOp
4512/07/2017, OE: Translation of integer division, '\\' in python, implemented in translate_integer_divide, called from visit_BinOp
4612/07/2017, OE: C variable definition handled in 'define_C_Vars'
47"""
48import ast
49import sys
50from ast import NodeVisitor
51
52BINOP_SYMBOLS = {}
53BINOP_SYMBOLS[ast.Add] = '+'
54BINOP_SYMBOLS[ast.Sub] = '-'
55BINOP_SYMBOLS[ast.Mult] = '*'
56BINOP_SYMBOLS[ast.Div] = '/'
57BINOP_SYMBOLS[ast.Mod] = '%'
58BINOP_SYMBOLS[ast.Pow] = '**'
59BINOP_SYMBOLS[ast.LShift] = '<<'
60BINOP_SYMBOLS[ast.RShift] = '>>'
61BINOP_SYMBOLS[ast.BitOr] = '|'
62BINOP_SYMBOLS[ast.BitXor] = '^'
63BINOP_SYMBOLS[ast.BitAnd] = '&'
64BINOP_SYMBOLS[ast.FloorDiv] = '//'
65
66BOOLOP_SYMBOLS = {}
67BOOLOP_SYMBOLS[ast.And] = '&&'
68BOOLOP_SYMBOLS[ast.Or]  = '||'
69
70CMPOP_SYMBOLS = {}
71CMPOP_SYMBOLS[ast.Eq] = '=='
72CMPOP_SYMBOLS[ast.NotEq] = '!='
73CMPOP_SYMBOLS[ast.Lt] = '<'
74CMPOP_SYMBOLS[ast.LtE] = '<='
75CMPOP_SYMBOLS[ast.Gt] = '>'
76CMPOP_SYMBOLS[ast.GtE] = '>='
77CMPOP_SYMBOLS[ast.Is] = 'is'
78CMPOP_SYMBOLS[ast.IsNot] = 'is not'
79CMPOP_SYMBOLS[ast.In] = 'in'
80CMPOP_SYMBOLS[ast.NotIn] = 'not in'
81
82UNARYOP_SYMBOLS = {}
83UNARYOP_SYMBOLS[ast.Invert] = '~'
84UNARYOP_SYMBOLS[ast.Not] = 'not'
85UNARYOP_SYMBOLS[ast.UAdd] = '+'
86UNARYOP_SYMBOLS[ast.USub] = '-'
87
88
89#def to_source(node, indent_with=' ' * 4, add_line_information=False):
90def to_source(node, func_name):
91    """This function can convert a node tree back into python sourcecode.
92    This is useful for debugging purposes, especially if you're dealing with
93    custom asts not generated by python itself.
94
95    It could be that the sourcecode is evaluable when the AST itself is not
96    compilable / evaluable.  The reason for this is that the AST contains some
97    more data than regular sourcecode does, which is dropped during
98    conversion.
99
100    Each level of indentation is replaced with `indent_with`.  Per default this
101    parameter is equal to four spaces as suggested by PEP 8, but it might be
102    adjusted to match the application's styleguide.
103
104    If `add_line_information` is set to `True` comments for the line numbers
105    of the nodes are added to the output.  This can be used to spot wrong line
106    number information of statement nodes.
107    """
108    generator = SourceGenerator(' ' * 4, False)
109    generator.required_functions = func_name
110    generator.visit(node)
111
112#    return ''.join(generator.result)
113    return ''.join(generator.c_proc)
114
115def isevaluable(s):
116    try:
117        eval(s)
118        return True
119    except:
120        return False
121   
122class SourceGenerator(NodeVisitor):
123    """This visitor is able to transform a well formed syntax tree into python
124    sourcecode.  For more details have a look at the docstring of the
125    `node_to_source` function.
126    """
127
128    def __init__(self, indent_with, add_line_information=False):
129        self.result = []
130        self.indent_with = indent_with
131        self.add_line_information = add_line_information
132        self.indentation = 0
133        self.new_lines = 0
134        self.c_proc = []
135# for C
136        self.signature_line = 0
137        self.arguments = []
138        self.name = ""
139        self.warnings = []
140        self.Statements = []
141        self.current_statement = ""
142        self.strMethodSignature = ""
143        self.C_Vars = []
144        self.C_IntVars = []
145        self.MathIncludeed = False
146        self.C_Pointers = []
147        self.C_DclPointers = []
148        self.C_Functions = []
149        self.C_Vectors = []
150        self.SubRef = False
151        self.InSubscript = False
152        self.Tuples = []
153        self.required_functions = []
154        self.is_sequence = False
155        self.visited_args = False
156
157    def write_python(self, x):
158        if self.new_lines:
159            if self.result:
160                self.result.append('\n' * self.new_lines)
161            self.result.append(self.indent_with * self.indentation)
162            self.new_lines = 0
163        self.result.append(x)
164
165    def write_c(self, x):
166        self.current_statement += x
167       
168    def add_c_line(self, x):
169        string = ''
170        for i in range (self.indentation):
171            string += ("    ")
172        string += str(x)
173        self.c_proc.append (str(string + "\n"))
174        x = ''
175
176    def add_current_line (self):
177        if (len(self.current_statement) > 0):
178            self.add_c_line (self.current_statement)
179            self.current_statement = ''
180
181    def AddUniqueVar (self, new_var):
182        if ((new_var not in self.C_Vars)):
183            self.C_Vars.append (str (new_var))
184
185    def WriteSincos (self, node):
186        angle = str(node.args[0].id)
187        self.write_c (node.args[1].id + " = sin (" + angle + ");")
188        self.add_current_line()
189        self.write_c (node.args[2].id + " = cos (" + angle + ");")
190        self.add_current_line()
191        for arg in node.args:
192            self.AddUniqueVar (arg.id)
193
194    def newline(self, node=None, extra=0):
195        self.new_lines = max(self.new_lines, 1 + extra)
196        if node is not None and self.add_line_information:
197            self.write_c('# line: %s' % node.lineno)
198            self.new_lines = 1
199        if (len(self.current_statement)):
200            self.Statements.append (self.current_statement)
201            self.current_statement = ''
202
203    def body(self, statements):
204        if (len(self.current_statement)):
205            self.add_current_line ()
206        self.new_line = True
207        self.indentation += 1
208        for stmt in statements:
209            target_name = ''
210            if (hasattr (stmt, 'targets')):
211                if (hasattr(stmt.targets[0],'id')):
212                    target_name = stmt.targets[0].id # target name needed for debug only
213            self.visit(stmt)
214        self.add_current_line () # just for breaking point. to be deleted.
215        self.indentation -= 1
216
217    def body_or_else(self, node):
218        self.body(node.body)
219        if node.orelse:
220            self.newline()
221            self.write_c('else:')
222            self.body(node.orelse)
223
224    def signature(self, node):
225        want_comma = []
226        def write_comma():
227            if want_comma:
228                self.write_c(', ')
229            else:
230                want_comma.append(True)
231# for C
232        for arg in node.args:
233            self.arguments.append (arg.arg)
234
235        padding = [None] * (len(node.args) - len(node.defaults))
236        for arg, default in zip(node.args, padding + node.defaults):
237            if default is not None:
238                self.warnings.append ("Default Parameter unknown to C")
239                w_str = "Default Parameters are unknown to C: '" + arg.arg + " = " + str(default.n) + "'"
240                self.warnings.append (w_str)
241#                self.write_python('=')
242#                self.visit(default)
243
244    def decorators(self, node):
245        for decorator in node.decorator_list:
246            self.newline(decorator)
247            self.write_python('@')
248            self.visit(decorator)
249
250    # Statements
251
252    def visit_Assert(self, node):
253        self.newline(node)
254        self.write_c('assert ')
255        self.visit(node.test)
256        if node.msg is not None:
257           self.write_python(', ')
258           self.visit(node.msg)
259
260    def define_C_Vars (self, target):
261        if (hasattr (target, 'id')):
262# a variable is considered an array if it apears in the agrument list
263# and being assigned to. For example, the variable p in the following
264# sniplet is a pointer, while q is not
265# def somefunc (p, q):
266#  p = q + 1
267#  return
268#
269            if (target.id not in self.C_Vars):
270                if (target.id in self.arguments):
271                    idx = self.arguments.index (target.id)
272                    new_target = self.arguments[idx] + "[0]"
273                    if (new_target not in self.C_Pointers):
274                        target.id = new_target
275                        self.C_Pointers.append (self.arguments[idx])
276                else:
277                    self.C_Vars.append (target.id)
278
279    def visit_Assign(self, node):
280        for idx, target in enumerate(node.targets): # multi assign, as in 'a = b = c = 7'
281            if idx:
282                self.write_c(' = ')
283            self.define_C_Vars (target)
284            self.visit(target)
285        if (len(self.Tuples) > 0):
286            tplTargets = list (self.Tuples)
287            self.Tuples.clear()
288        self.write_c(' = ')
289        self.is_sequence = False
290        self.visited_args = False
291        self.visit(node.value)
292        self.write_c(';')
293        self.add_current_line ()
294        for n,item in enumerate (self.Tuples):
295            self.visit(tplTargets[n])
296            self.write_c(' = ')
297            self.visit(item)
298            self.write_c(';')
299            self.add_current_line ()
300        if ((self.is_sequence) and (not self.visited_args)):
301            for target  in node.targets:
302                if (hasattr (target, 'id')):
303                    if ((target.id in self.C_Vars) and (target.id not in self.C_DclPointers)):
304                        if (target.id not in self.C_DclPointers):
305                            self.C_DclPointers.append (target.id)
306                            if (target.id in self.C_Vars):
307                                self.C_Vars.remove(target.id)
308        self.current_statement = ''
309
310    def visit_AugAssign(self, node):
311        if (node.target.id not in self.C_Vars):
312            if (node.target.id not in self.arguments):
313                self.C_Vars.append (node.target.id)
314        self.visit(node.target)
315        self.write_c(' ' + BINOP_SYMBOLS[type(node.op)] + '= ')
316        self.visit(node.value)
317        self.write_c(';')
318        self.add_current_line ()
319
320    def visit_ImportFrom(self, node):
321        self.newline(node)
322        self.write_python('from %s%s import ' % ('.' * node.level, node.module))
323        for idx, item in enumerate(node.names):
324            if idx:
325                self.write_python(', ')
326            self.write_python(item)
327
328    def visit_Import(self, node):
329        self.newline(node)
330        for item in node.names:
331            self.write_python('import ')
332            self.visit(item)
333
334    def visit_Expr(self, node):
335        self.newline(node)
336        self.generic_visit(node)
337
338    def listToDeclare(self, Vars):
339        s = ''
340        if (len(Vars) > 0):
341            s = ",".join (Vars)
342        return (s)
343
344    def insert_C_Vars (self, start_var):
345        fLine = False
346        if (len(self.C_Vars) > 0):
347            s = self.listToDeclare(self.C_Vars)
348            self.c_proc.insert (start_var, "    double " + s + ";\n")
349            fLine = True
350            start_var += 1
351        if (len(self.C_IntVars) > 0):
352            s = self.listToDeclare(self.C_IntVars)
353            self.c_proc.insert (start_var, "    int " + s + ";\n")
354            fLine = True
355            start_var += 1
356        if (len (self.C_Vectors) > 0):
357            s = self.listToDeclare(self.C_Vectors)
358            for n in range(len(self.C_Vectors)):
359                name = "vec" + str(n+1)
360                c_dcl = "    double " + name + "[] = {" + self.C_Vectors[n] + "};"
361                self.c_proc.insert (start_var, c_dcl + "\n")
362                start_var += 1
363        if (len (self.C_Pointers) > 0):
364            vars = ""
365            for n in range(len(self.C_Pointers)):
366                if (len(vars) > 0):
367                    vars += ", "
368                if (self.C_Pointers[n] not in self.arguments):
369                    vars += "*" + self.C_Pointers[n]
370            if (len(vars) > 0):
371                c_dcl = "    double " + vars + ";"
372                self.c_proc.insert (start_var, c_dcl + "\n")
373                start_var += 1
374        if (len (self.C_DclPointers) > 0):
375            vars = ''
376            for n in range(len(self.C_DclPointers)):
377                if (len(vars) > 0):
378                    vars += ', '
379                vars += "*" + self.C_DclPointers[n]
380            if (len(vars) > 0):
381                c_dcl = "    double " + vars + ";"
382                self.c_proc.insert (start_var, c_dcl + "\n")
383                start_var += 1
384        self.C_Vars.clear()
385        self.C_IntVars.clear()
386        self.C_Vectors.clear()
387        self.C_Pointers.clear()
388        if (fLine == True):
389            self.c_proc.insert (start_var, "\n")
390        return
391        s = ''
392        for n in range(len(self.C_Vars)):
393            s += str(self.C_Vars[n])
394            if n < len(self.C_Vars) - 1:
395                s += ", "
396        if (len(s) > 0):
397            self.c_proc.insert (start_var, "    double " + s + ";\n")
398            self.c_proc.insert (start_var + 1, "\n")
399
400    def writeInclude (self):
401        if (self.MathIncludeed == False):
402            self.add_c_line("#include <math.h>\n")
403            self.add_c_line("static double pi = 3.14159265359;\n")
404            self.MathIncludeed = True
405
406    def ListToString (self, strings):
407        s = ''
408        for n in range(len(strings)):
409            s += strings[n]
410            if (n < (len(strings) - 1)):
411                s += ", "
412        return (s)
413
414    def getMethodSignature (self):
415#        args_str = ListToString (self.arguments)
416        args_str = ''
417        for n in range(len(self.arguments)):
418            args_str += "double " + self.arguments[n]
419            if (n < (len(self.arguments) - 1)):
420                args_str += ", "
421        return (args_str)
422#        self.strMethodSignature = 'double ' + self.name + ' (' + args_str + ")"
423
424    def InsertSignature (self):
425        args_str = ''
426        for n in range(len(self.arguments)):
427            args_str += "double " + self.arguments[n]
428            if (self.arguments[n] in self.C_Pointers):
429                args_str += "[]"
430            if (n < (len(self.arguments) - 1)):
431                args_str += ", "
432        self.strMethodSignature = 'double ' + self.name + ' (' + args_str + ")"
433        if (self.signature_line > 0):
434            self.c_proc.insert (self.signature_line, self.strMethodSignature)
435
436    def visit_FunctionDef(self, node):
437        self.newline(extra=1)
438        self.decorators(node)
439        self.newline(node)
440        self.arguments = []
441        self.name = node.name
442        if self.name not in self.required_functions[0]:
443            return
444        print("Parsing '" + self.name + "'")
445        args_str = ""
446
447        self.visit(node.args)
448# for C
449        self.writeInclude()
450        self.getMethodSignature ()
451# for C
452        self.signature_line = len(self.c_proc)
453#        self.add_c_line(self.strMethodSignature)
454        self.add_c_line("\n{")
455        start_vars = len(self.c_proc) + 1
456        self.body(node.body)
457        self.add_c_line("}\n")
458        self.InsertSignature ()
459        self.insert_C_Vars (start_vars)
460        self.C_Pointers = []
461
462    def visit_ClassDef(self, node):
463        have_args = []
464        def paren_or_comma():
465            if have_args:
466                self.write_python(', ')
467            else:
468                have_args.append(True)
469                self.write_python('(')
470
471        self.newline(extra=2)
472        self.decorators(node)
473        self.newline(node)
474        self.write_python('class %s' % node.name)
475        for base in node.bases:
476            paren_or_comma()
477            self.visit(base)
478        # XXX: the if here is used to keep this module compatible
479        #      with python 2.6.
480        if hasattr(node, 'keywords'):
481            for keyword in node.keywords:
482                paren_or_comma()
483                self.write_python(keyword.arg + '=')
484                self.visit(keyword.value)
485            if node.starargs is not None:
486                paren_or_comma()
487                self.write_python('*')
488                self.visit(node.starargs)
489            if node.kwargs is not None:
490                paren_or_comma()
491                self.write_python('**')
492                self.visit(node.kwargs)
493        self.write_python(have_args and '):' or ':')
494        self.body(node.body)
495
496    def visit_If(self, node):
497        self.write_c('if ')
498        self.visit(node.test)
499        self.write_c(' {')
500        self.body(node.body)
501        self.add_c_line('}')
502        while True:
503            else_ = node.orelse
504            if len(else_) == 0:
505                break
506#            elif hasattr (else_, 'orelse'):
507            elif len(else_) == 1 and isinstance(else_[0], ast.If):
508                node = else_[0]
509#                self.newline()
510                self.write_c('else if ')
511                self.visit(node.test)
512                self.write_c(' {')
513                self.body(node.body)
514                self.add_current_line ()
515                self.add_c_line('}')
516#                break
517            else:
518                self.newline()
519                self.write_c('else {')
520                self.body(node.body)
521                self.add_c_line('}')
522                break
523
524    def getNodeLineNo (self, node):
525        line_number = -1
526        if (hasattr (node,'value')):
527            line_number = node.value.lineno
528        elif hasattr(node,'iter'):
529            if hasattr(node.iter,'lineno'):
530                line_number = node.iter.lineno
531        return (line_number)
532
533    def visit_For(self, node):
534        if (len(self.current_statement) > 0):
535            self.add_c_line(self.current_statement)
536            self.current_statement = ''
537        fForDone = False
538        if (hasattr(node.iter,'func')):
539            if (hasattr (node.iter.func,'id')):
540                if (node.iter.func.id == 'range'):
541                    if ('n' not in self.C_IntVars):
542                        self.C_IntVars.append ('n')
543                    self.write_c ("for (n=0 ; n < len(")
544                    for arg in node.iter.args:
545                        self.visit(arg)
546                    self.write_c (") ; n++) {")
547                    self.body_or_else(node)
548                    self.write_c ("}")
549                    fForDone = True
550        if (fForDone == False):
551            line_number = self.getNodeLineNo (node)
552            self.current_statement = ''
553            self.write_c('for ')
554            self.visit(node.target)
555            self.write_c(' in ')
556            self.visit(node.iter)
557            self.write_c(':')
558            errStr = "Conversion Error in function " + self.name + ", Line #" + str (line_number)
559            errStr += "\nPython for expression not supported: '" + self.current_statement + "'"
560            raise Exception(errStr)
561
562    def visit_While(self, node):
563        self.newline(node)
564        self.write_c('while ')
565        self.visit(node.test)
566        self.write_c(':')
567        self.body_or_else(node)
568
569    def visit_With(self, node):
570        self.newline(node)
571        self.write_python('with ')
572        self.visit(node.context_expr)
573        if node.optional_vars is not None:
574            self.write_python(' as ')
575            self.visit(node.optional_vars)
576        self.write_python(':')
577        self.body(node.body)
578
579    def visit_Pass(self, node):
580        self.newline(node)
581        self.write_python('pass')
582
583    def visit_Print(self, node):
584# XXX: python 2.6 only
585        self.newline(node)
586        self.write_c('print ')
587        want_comma = False
588        if node.dest is not None:
589            self.write_c(' >> ')
590            self.visit(node.dest)
591            want_comma = True
592        for value in node.values:
593            if want_comma:
594                self.write_c(', ')
595            self.visit(value)
596            want_comma = True
597        if not node.nl:
598            self.write_c(',')
599
600    def visit_Delete(self, node):
601        self.newline(node)
602        self.write_python('del ')
603        for idx, target in enumerate(node):
604            if idx:
605                self.write_python(', ')
606            self.visit(target)
607
608    def visit_TryExcept(self, node):
609        self.newline(node)
610        self.write_python('try:')
611        self.body(node.body)
612        for handler in node.handlers:
613            self.visit(handler)
614
615    def visit_TryFinally(self, node):
616        self.newline(node)
617        self.write_python('try:')
618        self.body(node.body)
619        self.newline(node)
620        self.write_python('finally:')
621        self.body(node.finalbody)
622
623    def visit_Global(self, node):
624        self.newline(node)
625        self.write_python('global ' + ', '.join(node.names))
626
627    def visit_Nonlocal(self, node):
628        self.newline(node)
629        self.write_python('nonlocal ' + ', '.join(node.names))
630
631    def visit_Return(self, node):
632        self.newline(node)
633        if node.value is None:
634            self.write_c('return')
635        else:
636            self.write_c('return (')
637            self.visit(node.value)
638        self.write_c(');')
639#      self.add_current_statement (self)
640        self.add_c_line (self.current_statement)
641        self.current_statement = ''
642
643    def visit_Break(self, node):
644        self.newline(node)
645        self.write_c('break')
646
647    def visit_Continue(self, node):
648        self.newline(node)
649        self.write_c('continue')
650
651    def visit_Raise(self, node):
652        # XXX: Python 2.6 / 3.0 compatibility
653        self.newline(node)
654        self.write_python('raise')
655        if hasattr(node, 'exc') and node.exc is not None:
656            self.write_python(' ')
657            self.visit(node.exc)
658            if node.cause is not None:
659                self.write_python(' from ')
660                self.visit(node.cause)
661        elif hasattr(node, 'type') and node.type is not None:
662            self.visit(node.type)
663            if node.inst is not None:
664                self.write_python(', ')
665                self.visit(node.inst)
666            if node.tback is not None:
667                self.write_python(', ')
668                self.visit(node.tback)
669
670    # Expressions
671
672    def visit_Attribute(self, node):
673        errStr = "Conversion Error in function " + self.name + ", Line #" + str (node.value.lineno)
674        errStr += "\nPython expression not supported: '" + node.value.id + "." + node.attr + "'"
675        raise Exception(errStr)
676        self.visit(node.value)
677        self.write_python('.' + node.attr)
678
679    def visit_Call(self, node):
680        want_comma = []
681        def write_comma():
682            if want_comma:
683                self.write_c(', ')
684            else:
685                want_comma.append(True)
686        if (hasattr (node.func, 'id')):
687            if (node.func.id not in self.C_Functions):
688                self.C_Functions.append (node.func.id)
689            if (node.func.id == 'abs'):
690                self.write_c ("fabs ")
691            elif (node.func.id == 'int'):
692                self.write_c('(int) ')
693            elif (node.func.id == "SINCOS"):
694                self.WriteSincos (node)
695                return
696            else:
697                self.visit(node.func)
698        else:
699            self.visit(node.func)
700#self.C_Functions
701        self.write_c('(')
702        for arg in node.args:
703            write_comma()
704            self.visited_args = True 
705            self.visit(arg)
706        for keyword in node.keywords:
707            write_comma()
708            self.write_c(keyword.arg + '=')
709            self.visit(keyword.value)
710        if hasattr (node, 'starargs'):
711            if node.starargs is not None:
712                write_comma()
713                self.write_c('*')
714                self.visit(node.starargs)
715        if hasattr (node, 'kwargs'):
716            if node.kwargs is not None:
717                write_comma()
718                self.write_c('**')
719                self.visit(node.kwargs)
720        self.write_c(')')
721
722    def visit_Name(self, node):
723        self.write_c(node.id)
724        if ((node.id in self.C_Pointers) and (not self.SubRef)):
725            self.write_c("[0]")
726        name = ""
727        sub = node.id.find("[")
728        if (sub > 0):
729            name = node.id[0:sub].strip()
730        else:
731            name = node.id
732#       add variable to C_Vars if it ins't there yet, not an argument and not a number
733        if ((name not in self.C_Functions) and (name not in self.C_Vars) and (name not in self.C_IntVars) and (name not in self.arguments) and (name.isnumeric () == False)):
734            if (self.InSubscript):
735                self.C_IntVars.append (node.id)
736            else:
737                self.C_Vars.append (node.id)
738
739    def visit_Str(self, node):
740        self.write_c(repr(node.s))
741
742    def visit_Bytes(self, node):
743        self.write_c(repr(node.s))
744
745    def visit_Num(self, node):
746        self.write_c(repr(node.n))
747
748    def visit_Tuple(self, node):
749        for idx, item in enumerate(node.elts):
750            if idx:
751                self.Tuples.append(item)
752            else:
753                self.visit(item)
754
755    def sequence_visit(left, right):
756        def visit(self, node):
757            self.is_sequence = True
758            s = ""
759            for idx, item in enumerate(node.elts):
760                if ((idx > 0) and (len(s) > 0)):
761                    s += ', '
762                if (hasattr (item, 'id')):
763                    s += item.id
764                elif (hasattr(item,'n')):
765                    s += str(item.n)
766            if (len(s) > 0):
767                self.C_Vectors.append (s)
768                vec_name = "vec"  + str(len(self.C_Vectors))
769                self.write_c(vec_name)
770                vec_name += "#"
771        return visit
772
773    visit_List = sequence_visit('[', ']')
774    visit_Set = sequence_visit('{', '}')
775    del sequence_visit
776
777    def visit_Dict(self, node):
778        self.write_python('{')
779        for idx, (key, value) in enumerate(zip(node.keys, node.values)):
780            if idx:
781                self.write_python(', ')
782            self.visit(key)
783            self.write_python(': ')
784            self.visit(value)
785        self.write_python('}')
786
787    def get_special_power (self, string):
788        function_name = ''
789        is_negative_exp = False
790        if (isevaluable(str(self.current_statement))):
791            exponent = eval(string)
792            is_negative_exp = exponent < 0
793            abs_exponent = abs(exponent)
794            if (abs_exponent == 2):
795                function_name = "square"
796            elif (abs_exponent == 3):
797                function_name = "cube"
798            elif (abs_exponent == 0.5):
799                function_name = "sqrt"
800            elif (abs_exponent == 1.0/3.0):
801                function_name = "cbrt"
802        if (function_name == ''):
803            function_name = "pow"
804        return function_name, is_negative_exp
805
806    def translate_power (self, node):
807# get exponent by visiting the right hand argument.
808        function_name = "pow"
809        temp_statement = self.current_statement
810# 'visit' functions write the results to the 'current_statement' class memnber
811# Here, a temporary variable, 'temp_statement', is used, that enables the
812# use of the 'visit' function
813        self.current_statement = ''
814        self.visit(node.right)
815        exponent = self.current_statement.replace(' ','')
816        function_name, is_negative_exp = self.get_special_power (self.current_statement)
817        self.current_statement = temp_statement
818        if (is_negative_exp):
819            self.write_c ("1.0 / (")
820        self.write_c (function_name + " (")
821        self.visit(node.left)
822        if (function_name == "pow"):
823            self.write_c(", ")
824            self.visit(node.right)
825        self.write_c(")")
826        if (is_negative_exp):
827            self.write_c(")")
828        self.write_c(" ")
829
830    def translate_integer_divide (self, node):
831        self.write_c ("(int) (")
832        self.visit(node.left)
833        self.write_c (") / (int) (")
834        self.visit(node.right)
835        self.write_c (")")
836
837    def visit_BinOp(self, node):
838        if ('%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.Pow]):
839            self.translate_power (node)
840        elif ('%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.FloorDiv]):
841            self.translate_integer_divide (node)
842        else:
843            self.visit(node.left)
844            self.write_c(' %s ' % BINOP_SYMBOLS[type(node.op)])
845            self.visit(node.right)
846#       for C
847    def visit_BoolOp(self, node):
848        self.write_c('(')
849        for idx, value in enumerate(node.values):
850            if idx:
851                self.write_c(' %s ' % BOOLOP_SYMBOLS[type(node.op)])
852            self.visit(value)
853        self.write_c(')')
854
855    def visit_Compare(self, node):
856        self.write_c('(')
857        self.visit(node.left)
858        for op, right in zip(node.ops, node.comparators):
859            self.write_c(' %s ' % CMPOP_SYMBOLS[type(op)])
860            self.visit(right)
861        self.write_c(')')
862
863    def visit_UnaryOp(self, node):
864        self.write_c('(')
865        op = UNARYOP_SYMBOLS[type(node.op)]
866        self.write_c(op)
867        if op == 'not':
868            self.write_c(' ')
869        self.visit(node.operand)
870        self.write_c(')')
871
872    def visit_Subscript(self, node):
873        if (node.value.id not in self.C_Pointers):
874            self.C_Pointers.append (node.value.id)
875        self.SubRef = True
876        self.visit(node.value)
877        self.SubRef = False
878        self.write_c('[')
879        self.InSubscript = True
880        self.visit(node.slice)
881        self.InSubscript = False
882        self.write_c(']')
883
884    def visit_Slice(self, node):
885        if node.lower is not None:
886            self.visit(node.lower)
887        self.write_python(':')
888        if node.upper is not None:
889            self.visit(node.upper)
890        if node.step is not None:
891            self.write_python(':')
892            if not (isinstance(node.step, Name) and node.step.id == 'None'):
893                self.visit(node.step)
894
895    def visit_ExtSlice(self, node):
896        for idx, item in node.dims:
897            if idx:
898                self.write_python(', ')
899            self.visit(item)
900
901    def visit_Yield(self, node):
902        self.write_python('yield ')
903        self.visit(node.value)
904
905    def visit_Lambda(self, node):
906        self.write_python('lambda ')
907        self.visit(node.args)
908        self.write_python(': ')
909        self.visit(node.body)
910
911    def visit_Ellipsis(self, node):
912        self.write_python('Ellipsis')
913
914    def generator_visit(left, right):
915        def visit(self, node):
916            self.write_python(left)
917            self.visit(node.elt)
918            for comprehension in node.generators:
919                self.visit(comprehension)
920            self.write_python(right)
921        return visit
922
923    visit_ListComp = generator_visit('[', ']')
924    visit_GeneratorExp = generator_visit('(', ')')
925    visit_SetComp = generator_visit('{', '}')
926    del generator_visit
927
928    def visit_DictComp(self, node):
929        self.write_python('{')
930        self.visit(node.key)
931        self.write_python(': ')
932        self.visit(node.value)
933        for comprehension in node.generators:
934            self.visit(comprehension)
935        self.write_python('}')
936
937    def visit_IfExp(self, node):
938        self.visit(node.body)
939        self.write_c(' if ')
940        self.visit(node.test)
941        self.write_c(' else ')
942        self.visit(node.orelse)
943
944    def visit_Starred(self, node):
945        self.write_c('*')
946        self.visit(node.value)
947
948    def visit_Repr(self, node):
949        # XXX: python 2.6 only
950        self.write_c('`')
951        self.visit(node.value)
952        self.write_python('`')
953
954    # Helper Nodes
955
956    def visit_alias(self, node):
957        self.write_python(node.name)
958        if node.asname is not None:
959            self.write_python(' as ' + node.asname)
960
961    def visit_comprehension(self, node):
962        self.write_c(' for ')
963        self.visit(node.target)
964        self.write_python(' in ')
965        self.visit(node.iter)
966        if node.ifs:
967            for if_ in node.ifs:
968                self.write_python(' if ')
969                self.visit(if_)
970
971#    def visit_excepthandler(self, node):
972#        self.newline(node)
973#        self.write_python('except')
974#        if node.type is not None:
975#            self.write_python(' ')
976#            self.visit(node.type)
977#            if node.name is not None:
978#                self.write_python(' as ')
979#                self.visit(node.name)
980#        self.body(node.body)
981
982    def visit_arguments(self, node):
983        self.signature(node)
984
985def Iq1(q, porod_scale, porod_exp, lorentz_scale, lorentz_length, peak_pos, lorentz_exp=17):
986    z1 = z2 = z = abs (q - peak_pos) * lorentz_length
987    if (q > p):
988        q = p + 17
989        p = q - 5
990    z3 = -8
991    inten = (porod_scale / q ** porod_exp
992                + lorentz_scale / (1 + z ** lorentz_exp))
993    return inten
994
995def Iq(q, porod_scale, porod_exp, lorentz_scale, lorentz_length, peak_pos, lorentz_exp=17):
996    z1 = z2 = z = abs (q - peak_pos) * lorentz_length
997    if (q > p):
998        q = p + 17
999        p = q - 5
1000    elif (q == p):
1001        q = p * q
1002        q *= z1
1003        p = z1
1004    elif (q == 17):
1005        q = p * q - 17
1006    else:
1007        q += 7
1008    z3 = -8
1009    inten = (porod_scale / q ** porod_exp
1010                + lorentz_scale / (1 + z ** lorentz_exp))
1011    return inten
1012
1013def print_function(f=None):
1014    """
1015    Print out the code for the function
1016    """
1017    # Include some comments to see if they get printed
1018    import ast
1019    import inspect
1020    if f is not None:
1021        tree = ast.parse(inspect.getsource(f))
1022        tree_source = to_source (tree)
1023        print(tree_source)
1024
1025def translate (functions, constants=0):
1026    sniplets = []
1027    fname = functions[1]
1028    python_file = open (fname, "r")
1029    source = python_file.read()
1030    python_file.close()
1031    tree = ast.parse (source)
1032    sniplet = to_source (tree, functions) # in the future add filename, offset, constants
1033    sniplets.append(sniplet)
1034    return ("\n".join(sniplets))
1035
1036def get_file_names ():
1037    fname_in = ""
1038    fname_out = ""
1039    if (len(sys.argv) > 1):
1040        fname_in = sys.argv[1]
1041        fname_base = os.path.splitext (fname_in)
1042        if (len (sys.argv) == 2):
1043            fname_out = str (fname_base[0]) + '.c'
1044        else:
1045            fname_out = sys.argv[2]
1046        if (len (fname_in) > 0):
1047            python_file = open (sys.argv[1], "r")
1048            if (len (fname_out) > 0):
1049                file_out = open (fname_out, "w+")
1050    return len(sys.argv), fname_in, fname_out
1051
1052if __name__ == "__main__":
1053    import os
1054    print("Parsing...using Python" + sys.version)
1055    try:
1056        fname_in = ""
1057        fname_out = ""
1058        if (len (sys.argv) == 1):
1059            print ("Usage:\npython parse01.py <infile> [<outfile>] (if omitted, output file is '<infile>.c'")
1060        else:
1061            fname_in = sys.argv[1]
1062            fname_base = os.path.splitext (fname_in)
1063            if (len (sys.argv) == 2):
1064                fname_out = str (fname_base[0]) + '.c'
1065            else:
1066                fname_out = sys.argv[2]
1067            if (len (fname_in) > 0):
1068                python_file = open (sys.argv[1], "r")
1069                if (len (fname_out) > 0):
1070                    file_out = open (fname_out, "w+")
1071                functions = ["MultAsgn", "Iq41", "Iq2"]
1072                tpls = [functions, fname_in, 0]
1073                c_txt = translate (tpls)
1074                file_out.write (c_txt)
1075                file_out.close()
1076    except Exception as excp:
1077        print ("Error:\n" + str(excp.args))
1078    print("...Done")
Note: See TracBrowser for help on using the repository browser.