source: sasmodels/sasmodels/py2c.py @ 71c5f4d

Last change on this file since 71c5f4d was 71c5f4d, checked in by Omer Eisenberg <omereis@…>, 6 years ago

included constants in C. Fixed bug in C for loop

  • Property mode set to 100644
File size: 38.9 KB
Line 
1
2"""
3    codegen
4    ~~~~~~~
5
6    Extension to ast that allow ast -> python code generation.
7
8    :copyright: Copyright 2008 by Armin Ronacher.
9    :license: BSD.
10"""
11"""
12    Variables definition in C
13    -------------------------
14    Defining variables within the Translate function is a bit of a guess work, using following rules.
15    *   By default, a variable is a 'double'.
16    *   Variable in a for loop is an int.
17    *   Variable that is references with brackets is an array of doubles. The variable within the brackets
18        is integer. For example, in the reference 'var1[var2]', var1 is a double array, and var2 is an integer.
19    *   Assignment to an argument makes that argument an array, and the index in that assignment is 0.
20        For example, the following python code
21            def func (arg1, arg2):
22                arg2 = 17.
23        is translated to the following C code
24            double func (double arg1)
25            {
26                arg2[0] = 17.0;
27            }
28        For example, the following python code is translated to the following C code
29            def func (arg1, arg2):          double func (double arg1) {
30                arg2 = 17.                      arg2[0] = 17.0;
31                                            }
32    *   All functions are defined as double, even if there is no return statement.
33
34
35Update Notes
36============
3711/22 14:15, O.E.   Each 'visit_*' method is to build a C statement string. It shold insert 4 blanks per indentation level.
38                    The 'body' method will combine all the strings, by adding the 'current_statement' to the c_proc string list
39   11/2017, OE: variables, argument definition implemented.
40   Note: An argument is considered an array if it is the target of an assignment. In that case it is translated to <var>[0]
4111/27/2017, OE: 'pow' basicly working
42  /12/2017, OE: Multiple assignment: a1,a2,...,an=b1,b2,...bn implemented
43  /12/2017, OE: Power function, including special cases of square(x) (pow(x,2)) and cube(x) (pow(x,3)), implemented in translate_power,
44                called from visit_BinOp
4512/07/2017, OE: Translation of integer division, '\\' in python, implemented in translate_integer_divide, called from visit_BinOp
4612/07/2017, OE: C variable definition handled in 'define_C_Vars'
47"""
48import ast
49import sys
50import os
51from ast import NodeVisitor
52
53BINOP_SYMBOLS = {}
54BINOP_SYMBOLS[ast.Add] = '+'
55BINOP_SYMBOLS[ast.Sub] = '-'
56BINOP_SYMBOLS[ast.Mult] = '*'
57BINOP_SYMBOLS[ast.Div] = '/'
58BINOP_SYMBOLS[ast.Mod] = '%'
59BINOP_SYMBOLS[ast.Pow] = '**'
60BINOP_SYMBOLS[ast.LShift] = '<<'
61BINOP_SYMBOLS[ast.RShift] = '>>'
62BINOP_SYMBOLS[ast.BitOr] = '|'
63BINOP_SYMBOLS[ast.BitXor] = '^'
64BINOP_SYMBOLS[ast.BitAnd] = '&'
65BINOP_SYMBOLS[ast.FloorDiv] = '//'
66
67BOOLOP_SYMBOLS = {}
68BOOLOP_SYMBOLS[ast.And] = '&&'
69BOOLOP_SYMBOLS[ast.Or]  = '||'
70
71CMPOP_SYMBOLS = {}
72CMPOP_SYMBOLS[ast.Eq] = '=='
73CMPOP_SYMBOLS[ast.NotEq] = '!='
74CMPOP_SYMBOLS[ast.Lt] = '<'
75CMPOP_SYMBOLS[ast.LtE] = '<='
76CMPOP_SYMBOLS[ast.Gt] = '>'
77CMPOP_SYMBOLS[ast.GtE] = '>='
78CMPOP_SYMBOLS[ast.Is] = 'is'
79CMPOP_SYMBOLS[ast.IsNot] = 'is not'
80CMPOP_SYMBOLS[ast.In] = 'in'
81CMPOP_SYMBOLS[ast.NotIn] = 'not in'
82
83UNARYOP_SYMBOLS = {}
84UNARYOP_SYMBOLS[ast.Invert] = '~'
85UNARYOP_SYMBOLS[ast.Not] = 'not'
86UNARYOP_SYMBOLS[ast.UAdd] = '+'
87UNARYOP_SYMBOLS[ast.USub] = '-'
88
89
90#def to_source(node, indent_with=' ' * 4, add_line_information=False):
91def to_source(node, func_name, global_vectors):
92    """This function can convert a node tree back into python sourcecode.
93    This is useful for debugging purposes, especially if you're dealing with
94    custom asts not generated by python itself.
95
96    It could be that the sourcecode is evaluable when the AST itself is not
97    compilable / evaluable.  The reason for this is that the AST contains some
98    more data than regular sourcecode does, which is dropped during
99    conversion.
100
101    Each level of indentation is replaced with `indent_with`.  Per default this
102    parameter is equal to four spaces as suggested by PEP 8, but it might be
103    adjusted to match the application's styleguide.
104
105    If `add_line_information` is set to `True` comments for the line numbers
106    of the nodes are added to the output.  This can be used to spot wrong line
107    number information of statement nodes.
108
109    global_vectors arguments is a list of strings, holding the names of global
110    variables already declared as arrays
111    """
112#    print(str(node))
113#    return
114    generator = SourceGenerator(' ' * 4, False)
115    generator.global_vectors = global_vectors
116#    generator.required_functions = func_name
117    generator.visit(node)
118
119#    return ''.join(generator.result)
120    return ''.join(generator.c_proc)
121
122EXTREN_DOUBLE = "extern double "
123
124def get_c_array_value(val):
125    val_str = str(val).strip()
126    val_str = val_str.replace('[','')
127    val_str = val_str.replace(']','')
128    val_str = val_str.strip()
129    double_blank = "  " in val_str
130    while (double_blank):
131        val_str = val_str.replace('  ',' ')
132        double_blank = "  " in val_str
133    val_str = val_str.replace(' ',',')
134    return (val_str)
135
136def write_include_const (snippets, constants):
137    c_globals=[]
138    snippets.append("#include <math.h>\n")
139    const_names = constants.keys()
140    for name in const_names:
141        c_globals.append(name)
142        val = constants[name]
143        var_dcl = "double "# + " = "
144        if (hasattr(val, "__len__")): # the value is an array
145            val_str = get_c_array_value(val)
146            var_dcl += str(name) + "[] = {" + str(val_str) + "}"
147        else:
148            var_dcl += str(name) + " = " + str(val)
149        var_dcl += ";\n"
150        snippets.append(var_dcl)
151           
152    return (c_globals)
153#    for c_var in constants:
154#        sTmp = EXTREN_DOUBLE + str(c_var) + ";"
155#        c_externs.append (EXTREN_DOUBLE + c_var + ";")
156#    snippets.append (c_externs)
157
158def isevaluable(s):
159    try:
160        eval(s)
161        return True
162    except:
163        return False
164   
165class SourceGenerator(NodeVisitor):
166    """This visitor is able to transform a well formed syntax tree into python
167    sourcecode.  For more details have a look at the docstring of the
168    `node_to_source` function.
169    """
170
171    def __init__(self, indent_with, add_line_information=False):
172        self.result = []
173        self.indent_with = indent_with
174        self.add_line_information = add_line_information
175        self.indentation = 0
176        self.new_lines = 0
177        self.c_proc = []
178# for C
179        self.signature_line = 0
180        self.arguments = []
181        self.name = ""
182        self.warnings = []
183        self.Statements = []
184        self.current_statement = ""
185        self.strMethodSignature = ""
186        self.C_Vars = []
187        self.C_IntVars = []
188        self.MathIncludeed = False
189        self.C_Pointers = []
190        self.C_DclPointers = []
191        self.C_Functions = []
192        self.C_Vectors = []
193        self.SubRef = False
194        self.InSubscript = False
195        self.Tuples = []
196        self.required_functions = []
197        self.is_sequence = False
198        self.visited_args = False
199        self.global_vectors = []
200
201    def write_python(self, x):
202        if self.new_lines:
203            if self.result:
204                self.result.append('\n' * self.new_lines)
205            self.result.append(self.indent_with * self.indentation)
206            self.new_lines = 0
207        self.result.append(x)
208
209    def write_c(self, x):
210        self.current_statement += x
211       
212    def add_c_line(self, x):
213        string = ''
214        for i in range (self.indentation):
215            string += ("    ")
216        string += str(x)
217        self.c_proc.append (str(string + "\n"))
218        x = ''
219
220    def add_current_line (self):
221        if (len(self.current_statement) > 0):
222            self.add_c_line (self.current_statement)
223            self.current_statement = ''
224
225    def AddUniqueVar (self, new_var):
226        if ((new_var not in self.C_Vars)):
227            self.C_Vars.append (str (new_var))
228
229    def WriteSincos (self, node):
230        angle = str(node.args[0].id)
231        self.write_c (node.args[1].id + " = sin (" + angle + ");")
232        self.add_current_line()
233        self.write_c (node.args[2].id + " = cos (" + angle + ");")
234        self.add_current_line()
235        for arg in node.args:
236            self.AddUniqueVar (arg.id)
237
238    def newline(self, node=None, extra=0):
239        self.new_lines = max(self.new_lines, 1 + extra)
240        if node is not None and self.add_line_information:
241            self.write_c('# line: %s' % node.lineno)
242            self.new_lines = 1
243        if (len(self.current_statement)):
244            self.Statements.append (self.current_statement)
245            self.current_statement = ''
246
247    def body(self, statements):
248        if (len(self.current_statement)):
249            self.add_current_line ()
250        self.new_line = True
251        self.indentation += 1
252        for stmt in statements:
253            target_name = ''
254            if (hasattr (stmt, 'targets')):
255                if (hasattr(stmt.targets[0],'id')):
256                    target_name = stmt.targets[0].id # target name needed for debug only
257            self.visit(stmt)
258        self.add_current_line () # just for breaking point. to be deleted.
259        self.indentation -= 1
260
261    def body_or_else(self, node):
262        self.body(node.body)
263        if node.orelse:
264            self.newline()
265            self.write_c('else:')
266            self.body(node.orelse)
267
268    def signature(self, node):
269        want_comma = []
270        def write_comma():
271            if want_comma:
272                self.write_c(', ')
273            else:
274                want_comma.append(True)
275# for C
276        for arg in node.args:
277            self.arguments.append (arg.arg)
278
279        padding = [None] * (len(node.args) - len(node.defaults))
280        for arg, default in zip(node.args, padding + node.defaults):
281            if default is not None:
282                self.warnings.append ("Default Parameter unknown to C")
283                w_str = "Default Parameters are unknown to C: '" + arg.arg + " = " + str(default.n) + "'"
284                self.warnings.append (w_str)
285#                self.write_python('=')
286#                self.visit(default)
287
288    def decorators(self, node):
289        for decorator in node.decorator_list:
290            self.newline(decorator)
291            self.write_python('@')
292            self.visit(decorator)
293
294    # Statements
295
296    def visit_Assert(self, node):
297        self.newline(node)
298        self.write_c('assert ')
299        self.visit(node.test)
300        if node.msg is not None:
301           self.write_python(', ')
302           self.visit(node.msg)
303
304    def define_C_Vars (self, target):
305        if (hasattr (target, 'id')):
306# a variable is considered an array if it apears in the agrument list
307# and being assigned to. For example, the variable p in the following
308# sniplet is a pointer, while q is not
309# def somefunc (p, q):
310#  p = q + 1
311#  return
312#
313            if (target.id not in self.C_Vars):
314                if (target.id in self.arguments):
315                    idx = self.arguments.index (target.id)
316                    new_target = self.arguments[idx] + "[0]"
317                    if (new_target not in self.C_Pointers):
318                        target.id = new_target
319                        self.C_Pointers.append (self.arguments[idx])
320                else:
321                    self.C_Vars.append (target.id)
322
323    def visit_Assign(self, node):
324        for idx, target in enumerate(node.targets): # multi assign, as in 'a = b = c = 7'
325            if idx:
326                self.write_c(' = ')
327            self.define_C_Vars (target)
328            self.visit(target)
329        if (len(self.Tuples) > 0):
330            tplTargets = list (self.Tuples)
331            self.Tuples.clear()
332        self.write_c(' = ')
333        self.is_sequence = False
334        self.visited_args = False
335        self.visit(node.value)
336        self.write_c(';')
337        self.add_current_line ()
338        for n,item in enumerate (self.Tuples):
339            self.visit(tplTargets[n])
340            self.write_c(' = ')
341            self.visit(item)
342            self.write_c(';')
343            self.add_current_line ()
344        if ((self.is_sequence) and (not self.visited_args)):
345            for target  in node.targets:
346                if (hasattr (target, 'id')):
347                    if ((target.id in self.C_Vars) and (target.id not in self.C_DclPointers)):
348                        if (target.id not in self.C_DclPointers):
349                            self.C_DclPointers.append (target.id)
350                            if (target.id in self.C_Vars):
351                                self.C_Vars.remove(target.id)
352        self.current_statement = ''
353
354    def visit_AugAssign(self, node):
355        if (node.target.id not in self.C_Vars):
356            if (node.target.id not in self.arguments):
357                self.C_Vars.append (node.target.id)
358        self.visit(node.target)
359        self.write_c(' ' + BINOP_SYMBOLS[type(node.op)] + '= ')
360        self.visit(node.value)
361        self.write_c(';')
362        self.add_current_line ()
363
364    def visit_ImportFrom(self, node):
365        self.newline(node)
366        self.write_python('from %s%s import ' % ('.' * node.level, node.module))
367        for idx, item in enumerate(node.names):
368            if idx:
369                self.write_python(', ')
370            self.write_python(item)
371
372    def visit_Import(self, node):
373        self.newline(node)
374        for item in node.names:
375            self.write_python('import ')
376            self.visit(item)
377
378    def visit_Expr(self, node):
379        self.newline(node)
380        self.generic_visit(node)
381
382    def listToDeclare(self, Vars):
383        s = ''
384        if (len(Vars) > 0):
385            s = ",".join (Vars)
386        return (s)
387
388    def write_C_Pointers (self, start_var):
389        if (len (self.C_DclPointers) > 0):
390            vars = ""
391            for c_ptr in self.C_DclPointers:
392                if (len(vars) > 0):
393                    vars += ", "
394                if (c_ptr not in self.arguments):
395                    vars += "*" + c_ptr
396                if (c_ptr in self.C_Vars):
397                    if (c_ptr in self.C_Vars):
398                        self.C_Vars.remove (c_ptr)
399            if (len(vars) > 0):
400                c_dcl = "    double " + vars + ";"
401                self.c_proc.insert (start_var, c_dcl + "\n")
402                start_var += 1
403        return start_var
404
405    def insert_C_Vars (self, start_var):
406        fLine = False
407        if (len(self.C_Vars) > 0):
408            s = self.listToDeclare(self.C_Vars)
409            self.c_proc.insert (start_var, "    double " + s + ";\n")
410            fLine = True
411            start_var += 1
412        if (len(self.C_IntVars) > 0):
413            s = self.listToDeclare(self.C_IntVars)
414            self.c_proc.insert (start_var, "    int " + s + ";\n")
415            fLine = True
416            start_var += 1
417        if (len (self.C_Vectors) > 0):
418            s = self.listToDeclare(self.C_Vectors)
419            for n in range(len(self.C_Vectors)):
420                name = "vec" + str(n+1)
421                c_dcl = "    double " + name + "[] = {" + self.C_Vectors[n] + "};"
422                self.c_proc.insert (start_var, c_dcl + "\n")
423                start_var += 1
424        if (len (self.C_Pointers) > 0):
425            vars = ""
426            for n in range(len(self.C_Pointers)):
427                if (len(vars) > 0):
428                    vars += ", "
429                if (self.C_Pointers[n] not in self.arguments):
430                    vars += "*" + self.C_Pointers[n]
431            if (len(vars) > 0):
432                c_dcl = "    double " + vars + ";"
433                self.c_proc.insert (start_var, c_dcl + "\n")
434                start_var += 1
435        self.C_Vars.clear()
436        self.C_IntVars.clear()
437        self.C_Vectors.clear()
438        self.C_Pointers.clear()
439        self.C_DclPointers.clear()
440        if (fLine == True):
441            self.c_proc.insert (start_var, "\n")
442        return
443        s = ''
444        for n in range(len(self.C_Vars)):
445            s += str(self.C_Vars[n])
446            if n < len(self.C_Vars) - 1:
447                s += ", "
448        if (len(s) > 0):
449            self.c_proc.insert (start_var, "    double " + s + ";\n")
450            self.c_proc.insert (start_var + 1, "\n")
451
452#    def writeInclude (self):
453#        if (self.MathIncludeed == False):
454#            self.add_c_line("#include <math.h>\n")
455#            self.add_c_line("static double pi = 3.14159265359;\n")
456#            self.MathIncludeed = True
457
458    def ListToString (self, strings):
459        s = ''
460        for n in range(len(strings)):
461            s += strings[n]
462            if (n < (len(strings) - 1)):
463                s += ", "
464        return (s)
465
466    def getMethodSignature (self):
467#        args_str = ListToString (self.arguments)
468        args_str = ''
469        for n in range(len(self.arguments)):
470            args_str += "double " + self.arguments[n]
471            if (n < (len(self.arguments) - 1)):
472                args_str += ", "
473        return (args_str)
474#        self.strMethodSignature = 'double ' + self.name + ' (' + args_str + ")"
475
476    def InsertSignature (self):
477        args_str = ''
478        for n in range(len(self.arguments)):
479            args_str += "double " + self.arguments[n]
480            if (self.arguments[n] in self.C_Pointers):
481                args_str += "[]"
482            if (n < (len(self.arguments) - 1)):
483                args_str += ", "
484        self.strMethodSignature = 'double ' + self.name + ' (' + args_str + ")"
485        if (self.signature_line >= 0):
486            self.c_proc.insert (self.signature_line, self.strMethodSignature)
487
488    def visit_FunctionDef(self, node):
489        self.newline(extra=1)
490        self.decorators(node)
491        self.newline(node)
492        self.arguments = []
493        self.name = node.name
494#        if self.name not in self.required_functions[0]:
495#            return
496        print("Parsing '" + self.name + "'")
497        args_str = ""
498
499        self.visit(node.args)
500# for C
501#        self.writeInclude()
502        self.getMethodSignature ()
503# for C
504        self.signature_line = len(self.c_proc)
505#        self.add_c_line(self.strMethodSignature)
506        self.add_c_line("\n{")
507        start_vars = len(self.c_proc) + 1
508        self.body(node.body)
509        self.add_c_line("}\n")
510        self.InsertSignature ()
511        self.insert_C_Vars (start_vars)
512        self.C_Pointers = []
513
514    def visit_ClassDef(self, node):
515        have_args = []
516        def paren_or_comma():
517            if have_args:
518                self.write_python(', ')
519            else:
520                have_args.append(True)
521                self.write_python('(')
522
523        self.newline(extra=2)
524        self.decorators(node)
525        self.newline(node)
526        self.write_python('class %s' % node.name)
527        for base in node.bases:
528            paren_or_comma()
529            self.visit(base)
530        # XXX: the if here is used to keep this module compatible
531        #      with python 2.6.
532        if hasattr(node, 'keywords'):
533            for keyword in node.keywords:
534                paren_or_comma()
535                self.write_python(keyword.arg + '=')
536                self.visit(keyword.value)
537            if node.starargs is not None:
538                paren_or_comma()
539                self.write_python('*')
540                self.visit(node.starargs)
541            if node.kwargs is not None:
542                paren_or_comma()
543                self.write_python('**')
544                self.visit(node.kwargs)
545        self.write_python(have_args and '):' or ':')
546        self.body(node.body)
547
548    def visit_If(self, node):
549        self.write_c('if ')
550        self.visit(node.test)
551        self.write_c(' {')
552        self.body(node.body)
553        self.add_c_line('}')
554        while True:
555            else_ = node.orelse
556            if len(else_) == 0:
557                break
558#            elif hasattr (else_, 'orelse'):
559            elif len(else_) == 1 and isinstance(else_[0], ast.If):
560                node = else_[0]
561#                self.newline()
562                self.write_c('else if ')
563                self.visit(node.test)
564                self.write_c(' {')
565                self.body(node.body)
566                self.add_current_line ()
567                self.add_c_line('}')
568#                break
569            else:
570                self.newline()
571                self.write_c('else {')
572                self.body(node.body)
573                self.add_c_line('}')
574                break
575
576    def getNodeLineNo (self, node):
577        line_number = -1
578        if (hasattr (node,'value')):
579            line_number = node.value.lineno
580        elif hasattr(node,'iter'):
581            if hasattr(node.iter,'lineno'):
582                line_number = node.iter.lineno
583        return (line_number)
584
585    def visit_For(self, node):
586        if (len(self.current_statement) > 0):
587            self.add_c_line(self.current_statement)
588            self.current_statement = ''
589        fForDone = False
590        if (hasattr(node.iter,'func')):
591            if (hasattr (node.iter.func,'id')):
592                if (node.iter.func.id == 'range'):
593                    if ('n' not in self.C_IntVars):
594                        self.C_IntVars.append ('n')
595                    self.write_c ("for (n=0 ; n < len(")
596                    for arg in node.iter.args:
597                        self.visit(arg)
598                    self.write_c (") ; n++) {")
599                    self.body_or_else(node)
600                    self.write_c ("}")
601                    self.add_current_line () # just for breaking point. to be deleted.
602                    fForDone = True
603        if (fForDone == False):
604            line_number = self.getNodeLineNo (node)
605            self.current_statement = ''
606            self.write_c('for ')
607            self.visit(node.target)
608            self.write_c(' in ')
609            self.visit(node.iter)
610            self.write_c(':')
611            errStr = "Conversion Error in function " + self.name + ", Line #" + str (line_number)
612            errStr += "\nPython for expression not supported: '" + self.current_statement + "'"
613            raise Exception(errStr)
614
615    def visit_While(self, node):
616        self.newline(node)
617        self.write_c('while ')
618        self.visit(node.test)
619        self.write_c(':')
620        self.body_or_else(node)
621
622    def visit_With(self, node):
623        self.newline(node)
624        self.write_python('with ')
625        self.visit(node.context_expr)
626        if node.optional_vars is not None:
627            self.write_python(' as ')
628            self.visit(node.optional_vars)
629        self.write_python(':')
630        self.body(node.body)
631
632    def visit_Pass(self, node):
633        self.newline(node)
634        self.write_python('pass')
635
636    def visit_Print(self, node):
637# XXX: python 2.6 only
638        self.newline(node)
639        self.write_c('print ')
640        want_comma = False
641        if node.dest is not None:
642            self.write_c(' >> ')
643            self.visit(node.dest)
644            want_comma = True
645        for value in node.values:
646            if want_comma:
647                self.write_c(', ')
648            self.visit(value)
649            want_comma = True
650        if not node.nl:
651            self.write_c(',')
652
653    def visit_Delete(self, node):
654        self.newline(node)
655        self.write_python('del ')
656        for idx, target in enumerate(node):
657            if idx:
658                self.write_python(', ')
659            self.visit(target)
660
661    def visit_TryExcept(self, node):
662        self.newline(node)
663        self.write_python('try:')
664        self.body(node.body)
665        for handler in node.handlers:
666            self.visit(handler)
667
668    def visit_TryFinally(self, node):
669        self.newline(node)
670        self.write_python('try:')
671        self.body(node.body)
672        self.newline(node)
673        self.write_python('finally:')
674        self.body(node.finalbody)
675
676    def visit_Global(self, node):
677        self.newline(node)
678        self.write_python('global ' + ', '.join(node.names))
679
680    def visit_Nonlocal(self, node):
681        self.newline(node)
682        self.write_python('nonlocal ' + ', '.join(node.names))
683
684    def visit_Return(self, node):
685        self.newline(node)
686        if node.value is None:
687            self.write_c('return')
688        else:
689            self.write_c('return (')
690            self.visit(node.value)
691        self.write_c(');')
692#      self.add_current_statement (self)
693        self.add_c_line (self.current_statement)
694        self.current_statement = ''
695
696    def visit_Break(self, node):
697        self.newline(node)
698        self.write_c('break')
699
700    def visit_Continue(self, node):
701        self.newline(node)
702        self.write_c('continue')
703
704    def visit_Raise(self, node):
705        # XXX: Python 2.6 / 3.0 compatibility
706        self.newline(node)
707        self.write_python('raise')
708        if hasattr(node, 'exc') and node.exc is not None:
709            self.write_python(' ')
710            self.visit(node.exc)
711            if node.cause is not None:
712                self.write_python(' from ')
713                self.visit(node.cause)
714        elif hasattr(node, 'type') and node.type is not None:
715            self.visit(node.type)
716            if node.inst is not None:
717                self.write_python(', ')
718                self.visit(node.inst)
719            if node.tback is not None:
720                self.write_python(', ')
721                self.visit(node.tback)
722
723    # Expressions
724
725    def visit_Attribute(self, node):
726        errStr = "Conversion Error in function " + self.name + ", Line #" + str (node.value.lineno)
727        errStr += "\nPython expression not supported: '" + node.value.id + "." + node.attr + "'"
728        raise Exception(errStr)
729        self.visit(node.value)
730        self.write_python('.' + node.attr)
731
732    def visit_Call(self, node):
733        want_comma = []
734        def write_comma():
735            if want_comma:
736                self.write_c(', ')
737            else:
738                want_comma.append(True)
739        if (hasattr (node.func, 'id')):
740            if (node.func.id not in self.C_Functions):
741                self.C_Functions.append (node.func.id)
742            if (node.func.id == 'abs'):
743                self.write_c ("fabs ")
744            elif (node.func.id == 'int'):
745                self.write_c('(int) ')
746            elif (node.func.id == "SINCOS"):
747                self.WriteSincos (node)
748                return
749            else:
750                self.visit(node.func)
751        else:
752            self.visit(node.func)
753#self.C_Functions
754        self.write_c('(')
755        for arg in node.args:
756            write_comma()
757            self.visited_args = True 
758            self.visit(arg)
759        for keyword in node.keywords:
760            write_comma()
761            self.write_c(keyword.arg + '=')
762            self.visit(keyword.value)
763        if hasattr (node, 'starargs'):
764            if node.starargs is not None:
765                write_comma()
766                self.write_c('*')
767                self.visit(node.starargs)
768        if hasattr (node, 'kwargs'):
769            if node.kwargs is not None:
770                write_comma()
771                self.write_c('**')
772                self.visit(node.kwargs)
773        self.write_c(')')
774
775    def visit_Name(self, node):
776        self.write_c(node.id)
777        if ((node.id in self.C_Pointers) and (not self.SubRef)):
778            self.write_c("[0]")
779        name = ""
780        sub = node.id.find("[")
781        if (sub > 0):
782            name = node.id[0:sub].strip()
783        else:
784            name = node.id
785#       add variable to C_Vars if it ins't there yet, not an argument and not a number
786        if (name not in self.global_vectors):
787            if ((name not in self.C_Functions) and (name not in self.C_Vars) and (name not in self.C_IntVars) and (name not in self.arguments) and (name.isnumeric () == False)):
788                if (self.InSubscript):
789                    self.C_IntVars.append (node.id)
790                else:
791                    self.C_Vars.append (node.id)
792
793    def visit_Str(self, node):
794        self.write_c(repr(node.s))
795
796    def visit_Bytes(self, node):
797        self.write_c(repr(node.s))
798
799    def visit_Num(self, node):
800        self.write_c(repr(node.n))
801
802    def visit_Tuple(self, node):
803        for idx, item in enumerate(node.elts):
804            if idx:
805                self.Tuples.append(item)
806            else:
807                self.visit(item)
808
809    def sequence_visit(left, right):
810        def visit(self, node):
811            self.is_sequence = True
812            s = ""
813            for idx, item in enumerate(node.elts):
814                if ((idx > 0) and (len(s) > 0)):
815                    s += ', '
816                if (hasattr (item, 'id')):
817                    s += item.id
818                elif (hasattr(item,'n')):
819                    s += str(item.n)
820            if (len(s) > 0):
821                self.C_Vectors.append (s)
822                vec_name = "vec"  + str(len(self.C_Vectors))
823                self.write_c(vec_name)
824                vec_name += "#"
825        return visit
826
827    visit_List = sequence_visit('[', ']')
828    visit_Set = sequence_visit('{', '}')
829    del sequence_visit
830
831    def visit_Dict(self, node):
832        self.write_python('{')
833        for idx, (key, value) in enumerate(zip(node.keys, node.values)):
834            if idx:
835                self.write_python(', ')
836            self.visit(key)
837            self.write_python(': ')
838            self.visit(value)
839        self.write_python('}')
840
841    def get_special_power (self, string):
842        function_name = ''
843        is_negative_exp = False
844        if (isevaluable(str(self.current_statement))):
845            exponent = eval(string)
846            is_negative_exp = exponent < 0
847            abs_exponent = abs(exponent)
848            if (abs_exponent == 2):
849                function_name = "square"
850            elif (abs_exponent == 3):
851                function_name = "cube"
852            elif (abs_exponent == 0.5):
853                function_name = "sqrt"
854            elif (abs_exponent == 1.0/3.0):
855                function_name = "cbrt"
856        if (function_name == ''):
857            function_name = "pow"
858        return function_name, is_negative_exp
859
860    def translate_power (self, node):
861# get exponent by visiting the right hand argument.
862        function_name = "pow"
863        temp_statement = self.current_statement
864# 'visit' functions write the results to the 'current_statement' class memnber
865# Here, a temporary variable, 'temp_statement', is used, that enables the
866# use of the 'visit' function
867        self.current_statement = ''
868        self.visit(node.right)
869        exponent = self.current_statement.replace(' ','')
870        function_name, is_negative_exp = self.get_special_power (self.current_statement)
871        self.current_statement = temp_statement
872        if (is_negative_exp):
873            self.write_c ("1.0 / (")
874        self.write_c (function_name + " (")
875        self.visit(node.left)
876        if (function_name == "pow"):
877            self.write_c(", ")
878            self.visit(node.right)
879        self.write_c(")")
880        if (is_negative_exp):
881            self.write_c(")")
882        self.write_c(" ")
883
884    def translate_integer_divide (self, node):
885        self.write_c ("(int) (")
886        self.visit(node.left)
887        self.write_c (") / (int) (")
888        self.visit(node.right)
889        self.write_c (")")
890
891    def visit_BinOp(self, node):
892        if ('%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.Pow]):
893            self.translate_power (node)
894        elif ('%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.FloorDiv]):
895            self.translate_integer_divide (node)
896        else:
897            self.visit(node.left)
898            self.write_c(' %s ' % BINOP_SYMBOLS[type(node.op)])
899            self.visit(node.right)
900#       for C
901    def visit_BoolOp(self, node):
902        self.write_c('(')
903        for idx, value in enumerate(node.values):
904            if idx:
905                self.write_c(' %s ' % BOOLOP_SYMBOLS[type(node.op)])
906            self.visit(value)
907        self.write_c(')')
908
909    def visit_Compare(self, node):
910        self.write_c('(')
911        self.visit(node.left)
912        for op, right in zip(node.ops, node.comparators):
913            self.write_c(' %s ' % CMPOP_SYMBOLS[type(op)])
914            self.visit(right)
915        self.write_c(')')
916
917    def visit_UnaryOp(self, node):
918        self.write_c('(')
919        op = UNARYOP_SYMBOLS[type(node.op)]
920        self.write_c(op)
921        if op == 'not':
922            self.write_c(' ')
923        self.visit(node.operand)
924        self.write_c(')')
925
926    def visit_Subscript(self, node):
927        if (node.value.id not in self.C_Pointers):
928            self.C_Pointers.append (node.value.id)
929        self.SubRef = True
930        self.visit(node.value)
931        self.SubRef = False
932        self.write_c('[')
933        self.InSubscript = True
934        self.visit(node.slice)
935        self.InSubscript = False
936        self.write_c(']')
937
938    def visit_Slice(self, node):
939        if node.lower is not None:
940            self.visit(node.lower)
941        self.write_python(':')
942        if node.upper is not None:
943            self.visit(node.upper)
944        if node.step is not None:
945            self.write_python(':')
946            if not (isinstance(node.step, Name) and node.step.id == 'None'):
947                self.visit(node.step)
948
949    def visit_ExtSlice(self, node):
950        for idx, item in node.dims:
951            if idx:
952                self.write_python(', ')
953            self.visit(item)
954
955    def visit_Yield(self, node):
956        self.write_python('yield ')
957        self.visit(node.value)
958
959    def visit_Lambda(self, node):
960        self.write_python('lambda ')
961        self.visit(node.args)
962        self.write_python(': ')
963        self.visit(node.body)
964
965    def visit_Ellipsis(self, node):
966        self.write_python('Ellipsis')
967
968    def generator_visit(left, right):
969        def visit(self, node):
970            self.write_python(left)
971            self.visit(node.elt)
972            for comprehension in node.generators:
973                self.visit(comprehension)
974            self.write_python(right)
975        return visit
976
977    visit_ListComp = generator_visit('[', ']')
978    visit_GeneratorExp = generator_visit('(', ')')
979    visit_SetComp = generator_visit('{', '}')
980    del generator_visit
981
982    def visit_DictComp(self, node):
983        self.write_python('{')
984        self.visit(node.key)
985        self.write_python(': ')
986        self.visit(node.value)
987        for comprehension in node.generators:
988            self.visit(comprehension)
989        self.write_python('}')
990
991    def visit_IfExp(self, node):
992        self.visit(node.body)
993        self.write_c(' if ')
994        self.visit(node.test)
995        self.write_c(' else ')
996        self.visit(node.orelse)
997
998    def visit_Starred(self, node):
999        self.write_c('*')
1000        self.visit(node.value)
1001
1002    def visit_Repr(self, node):
1003        # XXX: python 2.6 only
1004        self.write_c('`')
1005        self.visit(node.value)
1006        self.write_python('`')
1007
1008    # Helper Nodes
1009
1010    def visit_alias(self, node):
1011        self.write_python(node.name)
1012        if node.asname is not None:
1013            self.write_python(' as ' + node.asname)
1014
1015    def visit_comprehension(self, node):
1016        self.write_c(' for ')
1017        self.visit(node.target)
1018        self.write_python(' in ')
1019        self.visit(node.iter)
1020        if node.ifs:
1021            for if_ in node.ifs:
1022                self.write_python(' if ')
1023                self.visit(if_)
1024
1025#    def visit_excepthandler(self, node):
1026#        self.newline(node)
1027#        self.write_python('except')
1028#        if node.type is not None:
1029#            self.write_python(' ')
1030#            self.visit(node.type)
1031#            if node.name is not None:
1032#                self.write_python(' as ')
1033#                self.visit(node.name)
1034#        self.body(node.body)
1035
1036    def visit_arguments(self, node):
1037        self.signature(node)
1038
1039def Iq1(q, porod_scale, porod_exp, lorentz_scale, lorentz_length, peak_pos, lorentz_exp=17):
1040    z1 = z2 = z = abs (q - peak_pos) * lorentz_length
1041    if (q > p):
1042        q = p + 17
1043        p = q - 5
1044    z3 = -8
1045    inten = (porod_scale / q ** porod_exp
1046                + lorentz_scale / (1 + z ** lorentz_exp))
1047    return inten
1048
1049def Iq(q, porod_scale, porod_exp, lorentz_scale, lorentz_length, peak_pos, lorentz_exp=17):
1050    z1 = z2 = z = abs (q - peak_pos) * lorentz_length
1051    if (q > p):
1052        q = p + 17
1053        p = q - 5
1054    elif (q == p):
1055        q = p * q
1056        q *= z1
1057        p = z1
1058    elif (q == 17):
1059        q = p * q - 17
1060    else:
1061        q += 7
1062    z3 = -8
1063    inten = (porod_scale / q ** porod_exp
1064                + lorentz_scale / (1 + z ** lorentz_exp))
1065    return inten
1066
1067def print_function(f=None):
1068    """
1069    Print out the code for the function
1070    """
1071    # Include some comments to see if they get printed
1072    import ast
1073    import inspect
1074    if f is not None:
1075        tree = ast.parse(inspect.getsource(f))
1076        tree_source = to_source (tree)
1077        print(tree_source)
1078
1079
1080def write_consts (constants):
1081    f = False
1082    if (constants is not None):
1083        f = True
1084
1085def translate(functions, constants):
1086    # type: (List[Tuple[str, str, int]], Dict[str, Any]) -> str
1087#    print ("Size of functions is: " + str(len(functions)))
1088    if (os.path.exists ("xlate.c")):
1089        os.remove("xlate.c")
1090    snippets = []
1091    global_vectors = write_include_const (snippets, constants)
1092    const_keys = constants.keys ()
1093    for source, filename, offset in functions:
1094        tree = ast.parse(source)
1095        snippet = to_source(tree, filename, global_vectors)#, offset)
1096        snippets.append(snippet)
1097    if (constants is not None):
1098        write_consts (constants)
1099    c_text = "\n".join(snippets)
1100    try:
1101        translated = open ("xlate.c", "a+")
1102        translated.write (c_text)
1103        translated.close()
1104    except Exception as excp:
1105        strErr = "Error:\n" + str(excp.args)
1106        print(strErr)
1107    return (c_text)
1108
1109def translate_from_file (functions, constants=0):
1110    print ("Size of functions is: " + len(functions))
1111    sniplets = []
1112    fname = functions[1]
1113    python_file = open (fname, "r")
1114    source = python_file.read()
1115    python_file.close()
1116    tree = ast.parse (source)
1117    sniplet = to_source (tree, functions) # in the future add filename, offset, constants
1118    sniplets.append(sniplet)
1119    return ("\n".join(sniplets))
1120
1121def get_file_names ():
1122    fname_in = ""
1123    fname_out = ""
1124    if (len(sys.argv) > 1):
1125        fname_in = sys.argv[1]
1126        fname_base = os.path.splitext (fname_in)
1127        if (len (sys.argv) == 2):
1128            fname_out = str (fname_base[0]) + '.c'
1129        else:
1130            fname_out = sys.argv[2]
1131        if (len (fname_in) > 0):
1132            python_file = open (sys.argv[1], "r")
1133            if (len (fname_out) > 0):
1134                file_out = open (fname_out, "w+")
1135    return len(sys.argv), fname_in, fname_out
1136
1137if __name__ == "__main__":
1138    import os
1139    print("Parsing...using Python" + sys.version)
1140    try:
1141        fname_in = ""
1142        fname_out = ""
1143        if (len (sys.argv) == 1):
1144            print ("Usage:\npython parse01.py <infile> [<outfile>] (if omitted, output file is '<infile>.c'")
1145        else:
1146            fname_in = sys.argv[1]
1147            fname_base = os.path.splitext (fname_in)
1148            if (len (sys.argv) == 2):
1149                fname_out = str (fname_base[0]) + '.c'
1150            else:
1151                fname_out = sys.argv[2]
1152            if (len (fname_in) > 0):
1153                python_file = open (sys.argv[1], "r")
1154                if (len (fname_out) > 0):
1155                    file_out = open (fname_out, "w+")
1156                functions = ["MultAsgn", "Iq41", "Iq2"]
1157                tpls = [functions, fname_in, 0]
1158                c_txt = translate (tpls)
1159                file_out.write (c_txt)
1160                file_out.close()
1161    except Exception as excp:
1162        print ("Error:\n" + str(excp.args))
1163    print("...Done")
Note: See TracBrowser for help on using the repository browser.