source: sasmodels/sasmodels/py2c.py @ 5c2a0f2

Last change on this file since 5c2a0f2 was 5c2a0f2, checked in by Omer Eisenberg <omereis@…>, 6 years ago

writing constants in C source

  • Property mode set to 100644
File size: 39.7 KB
Line 
1
2"""
3    codegen
4    ~~~~~~~
5
6    Extension to ast that allow ast -> python code generation.
7
8    :copyright: Copyright 2008 by Armin Ronacher.
9    :license: BSD.
10"""
11"""
12    Variables definition in C
13    -------------------------
14    Defining variables within the Translate function is a bit of a guess work,
15    using following rules.
16    *   By default, a variable is a 'double'.
17    *   Variable in a for loop is an int.
18    *   Variable that is references with brackets is an array of doubles. The
19        variable within the brackets is integer. For example, in the
20        reference 'var1[var2]', var1 is a double array, and var2 is an integer.
21    *   Assignment to an argument makes that argument an array, and the index in
22        that assignment is 0.
23        For example, the following python code
24            def func(arg1, arg2):
25                arg2 = 17.
26        is translated to the following C code
27            double func(double arg1)
28            {
29                arg2[0] = 17.0;
30            }
31        For example, the following python code is translated to the following C code
32            def func(arg1, arg2):          double func(double arg1) {
33                arg2 = 17.                      arg2[0] = 17.0;
34                                            }
35    *   All functions are defined as double, even if there is no return statement.
36
37
38Update Notes
39============
4011/22 14:15, O.E.   Each 'visit_*' method is to build a C statement string. It
41                    shold insert 4 blanks per indentation level.
42                    The 'body' method will combine all the strings, by adding
43                    the 'current_statement' to the c_proc string list
44   11/2017, OE: variables, argument definition implemented.
45   Note: An argument is considered an array if it is the target of an
46        assignment. In that case it is translated to <var>[0]
4711/27/2017, OE: 'pow' basicly working
48  /12/2017, OE: Multiple assignment: a1,a2,...,an=b1,b2,...bn implemented
49  /12/2017, OE: Power function, including special cases of
50                square(x)(pow(x,2)) and cube(x)(pow(x,3)), implemented in
51                translate_power, called from visit_BinOp
5212/07/2017, OE: Translation of integer division, '\\' in python, implemented
53                in translate_integer_divide, called from visit_BinOp
5412/07/2017, OE: C variable definition handled in 'define_C_Vars'
55              : Python integer division, '//', translated to C in
56                'translate_integer_divide'
5712/15/2017, OE: Precedence maintained by writing opening and closing
58                parenthesesm '(',')', in procedure 'visit_BinOp'.
59"""
60import ast
61import sys
62from ast import NodeVisitor
63
64BINOP_SYMBOLS = {}
65BINOP_SYMBOLS[ast.Add] = '+'
66BINOP_SYMBOLS[ast.Sub] = '-'
67BINOP_SYMBOLS[ast.Mult] = '*'
68BINOP_SYMBOLS[ast.Div] = '/'
69BINOP_SYMBOLS[ast.Mod] = '%'
70BINOP_SYMBOLS[ast.Pow] = '**'
71BINOP_SYMBOLS[ast.LShift] = '<<'
72BINOP_SYMBOLS[ast.RShift] = '>>'
73BINOP_SYMBOLS[ast.BitOr] = '|'
74BINOP_SYMBOLS[ast.BitXor] = '^'
75BINOP_SYMBOLS[ast.BitAnd] = '&'
76BINOP_SYMBOLS[ast.FloorDiv] = '//'
77
78BOOLOP_SYMBOLS = {}
79BOOLOP_SYMBOLS[ast.And] = '&&'
80BOOLOP_SYMBOLS[ast.Or]  = '||'
81
82CMPOP_SYMBOLS = {}
83CMPOP_SYMBOLS[ast.Eq]    = '=='
84CMPOP_SYMBOLS[ast.NotEq] = '!='
85CMPOP_SYMBOLS[ast.Lt] = '<'
86CMPOP_SYMBOLS[ast.LtE] = '<='
87CMPOP_SYMBOLS[ast.Gt] = '>'
88CMPOP_SYMBOLS[ast.GtE] = '>='
89CMPOP_SYMBOLS[ast.Is] = 'is'
90CMPOP_SYMBOLS[ast.IsNot] = 'is not'
91CMPOP_SYMBOLS[ast.In] = 'in'
92CMPOP_SYMBOLS[ast.NotIn] = 'not in'
93
94UNARYOP_SYMBOLS = {}
95UNARYOP_SYMBOLS[ast.Invert] = '~'
96UNARYOP_SYMBOLS[ast.Not] = 'not'
97UNARYOP_SYMBOLS[ast.UAdd] = '+'
98UNARYOP_SYMBOLS[ast.USub] = '-'
99
100
101#def to_source(node, indent_with=' ' * 4, add_line_information=False):
102def to_source(node, func_name, constants=None):
103    """This function can convert a node tree back into python sourcecode.
104    This is useful for debugging purposes, especially if you're dealing with
105    custom asts not generated by python itself.
106
107    It could be that the sourcecode is evaluable when the AST itself is not
108    compilable / evaluable.  The reason for this is that the AST contains some
109    more data than regular sourcecode does, which is dropped during
110    conversion.
111
112    Each level of indentation is replaced with `indent_with`.  Per default this
113    parameter is equal to four spaces as suggested by PEP 8, but it might be
114    adjusted to match the application's styleguide.
115
116    If `add_line_information` is set to `True` comments for the line numbers
117    of the nodes are added to the output.  This can be used to spot wrong line
118    number information of statement nodes.
119    """
120    generator = SourceGenerator(' ' * 4, False, constants)
121#    generator.required_functions = func_name
122    generator.visit(node)
123
124#    return ''.join(generator.result)
125    return ''.join(generator.c_proc)
126
127def isevaluable(s):
128    try:
129        eval(s)
130        return True
131    except:
132        return False
133
134class SourceGenerator(NodeVisitor):
135    """This visitor is able to transform a well formed syntax tree into python
136    sourcecode.  For more details have a look at the docstring of the
137    `node_to_source` function.
138    """
139
140    def __init__(self, indent_with, add_line_information=False, constants=None):
141        self.result = []
142        self.indent_with = indent_with
143        self.add_line_information = add_line_information
144        self.indentation = 0
145        self.new_lines = 0
146        self.c_proc = []
147# for C
148        self.signature_line = 0
149        self.arguments = []
150        self.name = ""
151        self.warnings = []
152        self.Statements = []
153        self.current_statement = ""
154        self.strMethodSignature = ""
155        self.C_Vars = []
156        self.C_IntVars = []
157        self.MathIncludeed = False
158        self.C_Pointers = []
159        self.C_DclPointers = []
160        self.C_Functions = []
161        self.C_Vectors = []
162        self.C_Constants = constants
163        self.SubRef = False
164        self.InSubscript = False
165        self.Tuples = []
166        self.required_functions = []
167        self.is_sequence = False
168        self.visited_args = False
169
170    def write_python(self, x):
171        if self.new_lines:
172            if self.result:
173                self.result.append('\n' * self.new_lines)
174            self.result.append(self.indent_with * self.indentation)
175            self.new_lines = 0
176        self.result.append(x)
177
178    def write_c(self, x):
179        self.current_statement += x
180
181    def add_c_line(self, x):
182        string = ''
183        for i in range(self.indentation):
184            string += ("    ")
185        string += str(x)
186        self.c_proc.append(str(string + "\n"))
187        x = ''
188
189    def add_current_line(self):
190        if(len(self.current_statement) > 0):
191            self.add_c_line(self.current_statement)
192            self.current_statement = ''
193
194    def AddUniqueVar(self, new_var):
195        if((new_var not in self.C_Vars)):
196            self.C_Vars.append(str(new_var))
197
198    def WriteSincos(self, node):
199        angle = str(node.args[0].id)
200        self.write_c(node.args[1].id + " = sin(" + angle + ");")
201        self.add_current_line()
202        self.write_c(node.args[2].id + " = cos(" + angle + ");")
203        self.add_current_line()
204        for arg in node.args:
205            self.AddUniqueVar(arg.id)
206
207    def newline(self, node=None, extra=0):
208        self.new_lines = max(self.new_lines, 1 + extra)
209        if node is not None and self.add_line_information:
210            self.write_c('# line: %s' % node.lineno)
211            self.new_lines = 1
212        if(len(self.current_statement)):
213            self.Statements.append(self.current_statement)
214            self.current_statement = ''
215
216    def body(self, statements):
217        if(len(self.current_statement)):
218            self.add_current_line()
219        self.new_line = True
220        self.indentation += 1
221        for stmt in statements:
222            target_name = ''
223            if(hasattr(stmt, 'targets')):
224                if(hasattr(stmt.targets[0], 'id')):
225                    target_name = stmt.targets[0].id # target name needed for debug only
226            self.visit(stmt)
227        self.add_current_line() # just for breaking point. to be deleted.
228        self.indentation -= 1
229
230    def body_or_else(self, node):
231        self.body(node.body)
232        if node.orelse:
233            self.newline()
234            self.write_c('else:')
235            self.body(node.orelse)
236
237    def signature(self, node):
238        want_comma = []
239        def write_comma():
240            if want_comma:
241                self.write_c(', ')
242            else:
243                want_comma.append(True)
244# for C
245        for arg in node.args:
246            self.arguments.append(arg.arg)
247
248        padding = [None] *(len(node.args) - len(node.defaults))
249        for arg, default in zip(node.args, padding + node.defaults):
250            if default is not None:
251                self.warnings.append("Default Parameter unknown to C")
252                w_str = "Default Parameters are unknown to C: '" + arg.arg + \
253                        " = " + str(default.n) + "'"
254                self.warnings.append(w_str)
255#                self.write_python('=')
256#                self.visit(default)
257
258    def decorators(self, node):
259        for decorator in node.decorator_list:
260            self.newline(decorator)
261            self.write_python('@')
262            self.visit(decorator)
263
264    # Statements
265
266    def visit_Assert(self, node):
267        self.newline(node)
268        self.write_c('assert ')
269        self.visit(node.test)
270        if node.msg is not None:
271            self.write_python(', ')
272            self.visit(node.msg)
273
274    def define_C_Vars(self, target):
275        if(hasattr(target, 'id')):
276# a variable is considered an array if it apears in the agrument list
277# and being assigned to. For example, the variable p in the following
278# sniplet is a pointer, while q is not
279# def somefunc(p, q):
280#  p = q + 1
281#  return
282#
283            if(target.id not in self.C_Vars):
284                if(target.id in self.arguments):
285                    idx = self.arguments.index(target.id)
286                    new_target = self.arguments[idx] + "[0]"
287                    if(new_target not in self.C_Pointers):
288                        target.id = new_target
289                        self.C_Pointers.append(self.arguments[idx])
290                else:
291                    self.C_Vars.append(target.id)
292
293    def add_semi_colon(self):
294        semi_pos = self.current_statement.find(';')
295        if(semi_pos > 0.0):
296            self.current_statement = self.current_statement.replace(';','')
297        self.write_c(';')
298
299    def visit_Assign(self, node):
300        self.add_current_line()
301        for idx, target in enumerate(node.targets): # multi assign, as in 'a = b = c = 7'
302            if idx:
303                self.write_c(' = ')
304            self.define_C_Vars(target)
305            self.visit(target)
306        if(len(self.Tuples) > 0):
307            tplTargets = list(self.Tuples)
308            self.Tuples.clear()
309        self.write_c(' = ')
310        self.is_sequence = False
311        self.visited_args = False
312        self.visit(node.value)
313        self.add_semi_colon()
314#        self.write_c(';')
315        self.add_current_line()
316        for n, item in enumerate(self.Tuples):
317            self.visit(tplTargets[n])
318            self.write_c(' = ')
319            self.visit(item)
320            self.add_semi_colon()
321            self.add_current_line()
322        if((self.is_sequence) and (not self.visited_args)):
323            for target in node.targets:
324                if(hasattr(target, 'id')):
325                    if((target.id in self.C_Vars) and(target.id not in self.C_DclPointers)):
326                        if(target.id not in self.C_DclPointers):
327                            self.C_DclPointers.append(target.id)
328                            if(target.id in self.C_Vars):
329                                self.C_Vars.remove(target.id)
330        self.current_statement = ''
331
332    def visit_AugAssign(self, node):
333        if(node.target.id not in self.C_Vars):
334            if(node.target.id not in self.arguments):
335                self.C_Vars.append(node.target.id)
336        self.visit(node.target)
337        self.write_c(' ' + BINOP_SYMBOLS[type(node.op)] + '= ')
338        self.visit(node.value)
339        self.add_semi_colon()
340#        self.write_c(';')
341        self.add_current_line()
342
343    def visit_ImportFrom(self, node):
344        self.newline(node)
345        self.write_python('from %s%s import ' %('.' * node.level, node.module))
346        for idx, item in enumerate(node.names):
347            if idx:
348                self.write_python(', ')
349            self.write_python(item)
350
351    def visit_Import(self, node):
352        self.newline(node)
353        for item in node.names:
354            self.write_python('import ')
355            self.visit(item)
356
357    def visit_Expr(self, node):
358        self.newline(node)
359        self.generic_visit(node)
360
361    def listToDeclare(self, Vars):
362        s = ''
363        if(len(Vars) > 0):
364            s = ",".join(Vars)
365        return(s)
366
367    def write_C_Pointers(self, start_var):
368        if(len(self.C_DclPointers) > 0):
369            vars = ""
370            for c_ptr in self.C_DclPointers:
371                if(len(vars) > 0):
372                    vars += ", "
373                if(c_ptr not in self.arguments):
374                    vars += "*" + c_ptr
375                if(c_ptr in self.C_Vars):
376                    if(c_ptr in self.C_Vars):
377                        self.C_Vars.remove(c_ptr)
378            if(len(vars) > 0):
379                c_dcl = "    double " + vars + ";"
380                self.c_proc.insert(start_var, c_dcl + "\n")
381                start_var += 1
382        return start_var
383
384    def insert_C_Vars(self, start_var):
385        fLine = False
386        start_var = self.write_C_Pointers(start_var)
387        if(len(self.C_IntVars) > 0):
388            for var in self.C_IntVars:
389                if(var in self.C_Vars):
390                    self.C_Vars.remove(var)
391            s = self.listToDeclare(self.C_IntVars)
392            self.c_proc.insert(start_var, "    int " + s + ";\n")
393            fLine = True
394            start_var += 1
395
396        if(len(self.C_Vars) > 0):
397            s = self.listToDeclare(self.C_Vars)
398            self.c_proc.insert(start_var, "    double " + s + ";\n")
399            fLine = True
400            start_var += 1
401#        if(len(self.C_IntVars) > 0):
402#            s = self.listToDeclare(self.C_IntVars)
403#            self.c_proc.insert(start_var, "    int " + s + ";\n")
404#            fLine = True
405#            start_var += 1
406        if(len(self.C_Vectors) > 0):
407            s = self.listToDeclare(self.C_Vectors)
408            for n in range(len(self.C_Vectors)):
409                name = "vec" + str(n+1)
410                c_dcl = "    double " + name + "[] = {" + self.C_Vectors[n] + "};"
411                self.c_proc.insert(start_var, c_dcl + "\n")
412                start_var += 1
413        self.C_Vars.clear()
414        self.C_IntVars.clear()
415        self.C_Vectors.clear()
416        self.C_Pointers.clear()
417        self.C_DclPointers
418        if(fLine == True):
419            self.c_proc.insert(start_var, "\n")
420        return
421        s = ''
422        for n in range(len(self.C_Vars)):
423            s += str(self.C_Vars[n])
424            if n < len(self.C_Vars) - 1:
425                s += ", "
426        if(len(s) > 0):
427            self.c_proc.insert(start_var, "    double " + s + ";\n")
428            self.c_proc.insert(start_var + 1, "\n")
429
430    def writeInclude(self):
431        if(self.MathIncludeed == False):
432            self.add_c_line("#include <math.h>\n")
433            self.add_c_line("static double pi = 3.14159265359;\n")
434            self.MathIncludeed = True
435
436    def ListToString(self, strings):
437        s = ''
438        for n in range(len(strings)):
439            s += strings[n]
440            if(n < (len(strings) - 1)):
441                s += ", "
442        return(s)
443
444    def getMethodSignature(self):
445#        args_str = ListToString(self.arguments)
446        args_str = ''
447        for n in range(len(self.arguments)):
448            args_str += "double " + self.arguments[n]
449            if(n < (len(self.arguments) - 1)):
450                args_str += ", "
451        return(args_str)
452#        self.strMethodSignature = 'double ' + self.name + '(' + args_str + ")"
453
454    def InsertSignature(self):
455        args_str = ''
456        for n in range(len(self.arguments)):
457            args_str += "double " + self.arguments[n]
458            if(self.arguments[n] in self.C_Pointers):
459                args_str += "[]"
460            if(n < (len(self.arguments) - 1)):
461                args_str += ", "
462        self.strMethodSignature = 'double ' + self.name + '(' + args_str + ")"
463        if(self.signature_line >= 0):
464            self.c_proc.insert(self.signature_line, self.strMethodSignature)
465
466    def visit_FunctionDef(self, node):
467        self.newline(extra=1)
468        self.decorators(node)
469        self.newline(node)
470        self.arguments = []
471        self.name = node.name
472#        if self.name not in self.required_functions[0]:
473#           return
474        print("Parsing '" + self.name + "'")
475        args_str = ""
476
477        self.visit(node.args)
478# for C
479#        self.writeInclude()
480        self.getMethodSignature()
481# for C
482        self.signature_line = len(self.c_proc)
483#        self.add_c_line(self.strMethodSignature)
484        self.add_c_line("\n{")
485        start_vars = len(self.c_proc) + 1
486        self.body(node.body)
487        self.add_c_line("}\n")
488        self.InsertSignature()
489        self.insert_C_Vars(start_vars)
490        self.C_Pointers = []
491
492    def visit_ClassDef(self, node):
493        have_args = []
494        def paren_or_comma():
495            if have_args:
496                self.write_python(', ')
497            else:
498                have_args.append(True)
499                self.write_python('(')
500
501        self.newline(extra=2)
502        self.decorators(node)
503        self.newline(node)
504        self.write_python('class %s' % node.name)
505        for base in node.bases:
506            paren_or_comma()
507            self.visit(base)
508        # XXX: the if here is used to keep this module compatible
509        #      with python 2.6.
510        if hasattr(node, 'keywords'):
511            for keyword in node.keywords:
512                paren_or_comma()
513                self.write_python(keyword.arg + '=')
514                self.visit(keyword.value)
515            if node.starargs is not None:
516                paren_or_comma()
517                self.write_python('*')
518                self.visit(node.starargs)
519            if node.kwargs is not None:
520                paren_or_comma()
521                self.write_python('**')
522                self.visit(node.kwargs)
523        self.write_python(have_args and '):' or ':')
524        self.body(node.body)
525
526    def visit_If(self, node):
527        self.write_c('if ')
528        self.visit(node.test)
529        self.write_c(' {')
530        self.body(node.body)
531        self.add_c_line('}')
532        while True:
533            else_ = node.orelse
534            if len(else_) == 0:
535                break
536#            elif hasattr(else_, 'orelse'):
537            elif len(else_) == 1 and isinstance(else_[0], ast.If):
538                node = else_[0]
539#                self.newline()
540                self.write_c('else if ')
541                self.visit(node.test)
542                self.write_c(' {')
543                self.body(node.body)
544                self.add_current_line()
545                self.add_c_line('}')
546#                break
547            else:
548                self.newline()
549                self.write_c('else {')
550                self.body(node.body)
551                self.add_c_line('}')
552                break
553
554    def getNodeLineNo(self, node):
555        line_number = -1
556        if(hasattr(node,'value')):
557            line_number = node.value.lineno
558        elif hasattr(node,'iter'):
559            if hasattr(node.iter,'lineno'):
560                line_number = node.iter.lineno
561        return(line_number)
562
563    def GetNodeAsString(self, node):
564        res = ''
565        if(hasattr(node, 'n')):
566            res = str(node.n)
567        elif(hasattr(node, 'id')):
568            res = node.id
569        return(res)
570
571    def GetForRange(self, node):
572        stop = ""
573        start = '0'
574        step = '1'
575        for_args = []
576        temp_statement = self.current_statement
577        self.current_statement = ''
578        for arg in node.iter.args:
579            self.visit(arg)
580            for_args.append(self.current_statement)
581            self.current_statement = ''
582        self.current_statement = temp_statement
583        if(len(for_args) == 1):
584            stop = for_args[0]
585        elif(len(for_args) == 2):
586            start = for_args[0]
587            stop = for_args[1]
588        elif(len(for_args) == 3):
589            start = for_args[0]
590            stop = for_args[1]
591            start = for_args[2]
592        else:
593            raise("Ilegal for loop parameters")
594        return(start, stop, step)
595
596    def visit_For(self, node):
597# node: for iterator is stored in node.target.
598# Iterator name is in node.target.id.
599        self.add_current_line()
600#        if(len(self.current_statement) > 0):
601#            self.add_c_line(self.current_statement)
602#            self.current_statement = ''
603        fForDone = False
604        self.current_statement = ''
605        if(hasattr(node.iter, 'func')):
606            if(hasattr(node.iter.func, 'id')):
607                if(node.iter.func.id == 'range'):
608                    self.visit(node.target)
609                    iterator = self.current_statement
610                    self.current_statement = ''
611                    if(iterator not in self.C_IntVars):
612                        self.C_IntVars.append(iterator)
613                    start, stop, step = self.GetForRange(node)
614                    self.write_c("for(" + iterator + "=" + str(start) + \
615                                  " ; " + iterator + " < " + str(stop) + \
616                                  " ; " + iterator + " += " + str(step) + ") {")
617                    self.body_or_else(node)
618                    self.write_c("}")
619                    fForDone = True
620        if(fForDone == False):
621            line_number = self.getNodeLineNo(node)
622            self.current_statement = ''
623            self.write_c('for ')
624            self.visit(node.target)
625            self.write_c(' in ')
626            self.visit(node.iter)
627            self.write_c(':')
628            errStr = "Conversion Error in function " + self.name + ", Line #" + str(line_number)
629            errStr += "\nPython for expression not supported: '" + self.current_statement + "'"
630            raise Exception(errStr)
631
632    def visit_While(self, node):
633        self.newline(node)
634        self.write_c('while ')
635        self.visit(node.test)
636        self.write_c(':')
637        self.body_or_else(node)
638
639    def visit_With(self, node):
640        self.newline(node)
641        self.write_python('with ')
642        self.visit(node.context_expr)
643        if node.optional_vars is not None:
644            self.write_python(' as ')
645            self.visit(node.optional_vars)
646        self.write_python(':')
647        self.body(node.body)
648
649    def visit_Pass(self, node):
650        self.newline(node)
651        self.write_python('pass')
652
653    def visit_Print(self, node):
654# XXX: python 2.6 only
655        self.newline(node)
656        self.write_c('print ')
657        want_comma = False
658        if node.dest is not None:
659            self.write_c(' >> ')
660            self.visit(node.dest)
661            want_comma = True
662        for value in node.values:
663            if want_comma:
664                self.write_c(', ')
665            self.visit(value)
666            want_comma = True
667        if not node.nl:
668            self.write_c(',')
669
670    def visit_Delete(self, node):
671        self.newline(node)
672        self.write_python('del ')
673        for idx, target in enumerate(node):
674            if idx:
675                self.write_python(', ')
676            self.visit(target)
677
678    def visit_TryExcept(self, node):
679        self.newline(node)
680        self.write_python('try:')
681        self.body(node.body)
682        for handler in node.handlers:
683            self.visit(handler)
684
685    def visit_TryFinally(self, node):
686        self.newline(node)
687        self.write_python('try:')
688        self.body(node.body)
689        self.newline(node)
690        self.write_python('finally:')
691        self.body(node.finalbody)
692
693    def visit_Global(self, node):
694        self.newline(node)
695        self.write_python('global ' + ', '.join(node.names))
696
697    def visit_Nonlocal(self, node):
698        self.newline(node)
699        self.write_python('nonlocal ' + ', '.join(node.names))
700
701    def visit_Return(self, node):
702        self.newline(node)
703        if node.value is None:
704            self.write_c('return')
705        else:
706            self.write_c('return(')
707            self.visit(node.value)
708        self.write_c(')')
709        self.add_semi_colon()
710        self.add_c_line(self.current_statement)
711        self.current_statement = ''
712
713    def visit_Break(self, node):
714        self.newline(node)
715        self.write_c('break')
716
717    def visit_Continue(self, node):
718        self.newline(node)
719        self.write_c('continue')
720
721    def visit_Raise(self, node):
722        # XXX: Python 2.6 / 3.0 compatibility
723        self.newline(node)
724        self.write_python('raise')
725        if hasattr(node, 'exc') and node.exc is not None:
726            self.write_python(' ')
727            self.visit(node.exc)
728            if node.cause is not None:
729                self.write_python(' from ')
730                self.visit(node.cause)
731        elif hasattr(node, 'type') and node.type is not None:
732            self.visit(node.type)
733            if node.inst is not None:
734                self.write_python(', ')
735                self.visit(node.inst)
736            if node.tback is not None:
737                self.write_python(', ')
738                self.visit(node.tback)
739
740    # Expressions
741
742    def visit_Attribute(self, node):
743        errStr = "Conversion Error in function " + self.name + ", Line #" + str(node.value.lineno)
744        errStr += "\nPython expression not supported: '" + node.value.id + "." + node.attr + "'"
745        raise Exception(errStr)
746        self.visit(node.value)
747        self.write_python('.' + node.attr)
748
749    def visit_Call(self, node):
750        want_comma = []
751        def write_comma():
752            if want_comma:
753                self.write_c(', ')
754            else:
755                want_comma.append(True)
756        if(hasattr(node.func, 'id')):
757            if(node.func.id not in self.C_Functions):
758                self.C_Functions.append(node.func.id)
759            if(node.func.id == 'abs'):
760                self.write_c("fabs ")
761            elif(node.func.id == 'int'):
762                self.write_c('(int) ')
763            elif(node.func.id == "SINCOS"):
764                self.WriteSincos(node)
765                return
766            else:
767                self.visit(node.func)
768        else:
769            self.visit(node.func)
770#self.C_Functions
771        self.write_c('(')
772        for arg in node.args:
773            write_comma()
774            self.visited_args = True
775            self.visit(arg)
776        for keyword in node.keywords:
777            write_comma()
778            self.write_c(keyword.arg + '=')
779            self.visit(keyword.value)
780        if hasattr(node, 'starargs'):
781            if node.starargs is not None:
782                write_comma()
783                self.write_c('*')
784                self.visit(node.starargs)
785        if hasattr(node, 'kwargs'):
786            if node.kwargs is not None:
787                write_comma()
788                self.write_c('**')
789                self.visit(node.kwargs)
790        self.write_c(');')
791
792    def visit_Name(self, node):
793        self.write_c(node.id)
794        if((node.id in self.C_Pointers) and(not self.SubRef)):
795            self.write_c("[0]")
796        name = ""
797        sub = node.id.find("[")
798        if(sub > 0):
799            name = node.id[0:sub].strip()
800        else:
801            name = node.id
802#       add variable to C_Vars if it ins't there yet, not an argument and not a number
803        if ((name not in self.C_Functions) and (name not in self.C_Vars) and \
804            (name not in self.C_IntVars) and (name not in self.arguments) and \
805            (name not in self.C_Constants) and (name.isnumeric() == False)):
806            if(self.InSubscript):
807                self.C_IntVars.append(node.id)
808            else:
809                self.C_Vars.append(node.id)
810
811    def visit_Str(self, node):
812        self.write_c(repr(node.s))
813
814    def visit_Bytes(self, node):
815        self.write_c(repr(node.s))
816
817    def visit_Num(self, node):
818        self.write_c(repr(node.n))
819
820    def visit_Tuple(self, node):
821        for idx, item in enumerate(node.elts):
822            if idx:
823                self.Tuples.append(item)
824            else:
825                self.visit(item)
826
827    def sequence_visit(left, right):
828        def visit(self, node):
829            self.is_sequence = True
830            s = ""
831            for idx, item in enumerate(node.elts):
832                if((idx > 0) and(len(s) > 0)):
833                    s += ', '
834                if(hasattr(item, 'id')):
835                    s += item.id
836                elif(hasattr(item, 'n')):
837                    s += str(item.n)
838            if(len(s) > 0):
839                self.C_Vectors.append(s)
840                vec_name = "vec"  + str(len(self.C_Vectors))
841                self.write_c(vec_name)
842                vec_name += "#"
843        return visit
844
845    visit_List = sequence_visit('[', ']')
846    visit_Set = sequence_visit('{', '}')
847    del sequence_visit
848
849    def visit_Dict(self, node):
850        self.write_python('{')
851        for idx, (key, value) in enumerate(zip(node.keys, node.values)):
852            if idx:
853                self.write_python(', ')
854            self.visit(key)
855            self.write_python(': ')
856            self.visit(value)
857        self.write_python('}')
858
859    def get_special_power(self, string):
860        function_name = ''
861        is_negative_exp = False
862        if(isevaluable(str(self.current_statement))):
863            exponent = eval(string)
864            is_negative_exp = exponent < 0
865            abs_exponent = abs(exponent)
866            if(abs_exponent == 2):
867                function_name = "square"
868            elif(abs_exponent == 3):
869                function_name = "cube"
870            elif(abs_exponent == 0.5):
871                function_name = "sqrt"
872            elif(abs_exponent == 1.0/3.0):
873                function_name = "cbrt"
874        if(function_name == ''):
875            function_name = "pow"
876        return function_name, is_negative_exp
877
878    def translate_power(self, node):
879# get exponent by visiting the right hand argument.
880        function_name = "pow"
881        temp_statement = self.current_statement
882# 'visit' functions write the results to the 'current_statement' class memnber
883# Here, a temporary variable, 'temp_statement', is used, that enables the
884# use of the 'visit' function
885        self.current_statement = ''
886        self.visit(node.right)
887        exponent = self.current_statement.replace(' ', '')
888        function_name, is_negative_exp = self.get_special_power(self.current_statement)
889        self.current_statement = temp_statement
890        if(is_negative_exp):
891            self.write_c("1.0 /(")
892        self.write_c(function_name + "(")
893        self.visit(node.left)
894        if(function_name == "pow"):
895            self.write_c(", ")
896            self.visit(node.right)
897        self.write_c(")")
898        if(is_negative_exp):
899            self.write_c(")")
900        self.write_c(" ")
901
902    def translate_integer_divide(self, node):
903        self.write_c("(int)(")
904        self.visit(node.left)
905        self.write_c(") /(int)(")
906        self.visit(node.right)
907        self.write_c(")")
908
909    def visit_BinOp(self, node):
910        self.write_c("(")
911        if('%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.Pow]):
912            self.translate_power(node)
913        elif('%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.FloorDiv]):
914            self.translate_integer_divide(node)
915        else:
916            self.visit(node.left)
917            self.write_c(' %s ' % BINOP_SYMBOLS[type(node.op)])
918            self.visit(node.right)
919        self.write_c(")")
920
921#       for C
922    def visit_BoolOp(self, node):
923        self.write_c('(')
924        for idx, value in enumerate(node.values):
925            if idx:
926                self.write_c(' %s ' % BOOLOP_SYMBOLS[type(node.op)])
927            self.visit(value)
928        self.write_c(')')
929
930    def visit_Compare(self, node):
931        self.write_c('(')
932        self.visit(node.left)
933        for op, right in zip(node.ops, node.comparators):
934            self.write_c(' %s ' % CMPOP_SYMBOLS[type(op)])
935            self.visit(right)
936        self.write_c(')')
937
938    def visit_UnaryOp(self, node):
939        self.write_c('(')
940        op = UNARYOP_SYMBOLS[type(node.op)]
941        self.write_c(op)
942        if op == 'not':
943            self.write_c(' ')
944        self.visit(node.operand)
945        self.write_c(')')
946
947    def visit_Subscript(self, node):
948        if (node.value.id not in self.C_Constants):
949            if(node.value.id not in self.C_Pointers):
950                self.C_Pointers.append(node.value.id)
951        self.SubRef = True
952        self.visit(node.value)
953        self.SubRef = False
954        self.write_c('[')
955        self.InSubscript = True
956        self.visit(node.slice)
957        self.InSubscript = False
958        self.write_c(']')
959
960    def visit_Slice(self, node):
961        if node.lower is not None:
962            self.visit(node.lower)
963        self.write_python(':')
964        if node.upper is not None:
965            self.visit(node.upper)
966        if node.step is not None:
967            self.write_python(':')
968            if not(isinstance(node.step, Name) and node.step.id == 'None'):
969                self.visit(node.step)
970
971    def visit_ExtSlice(self, node):
972        for idx, item in node.dims:
973            if idx:
974                self.write_python(', ')
975            self.visit(item)
976
977    def visit_Yield(self, node):
978        self.write_python('yield ')
979        self.visit(node.value)
980
981    def visit_Lambda(self, node):
982        self.write_python('lambda ')
983        self.visit(node.args)
984        self.write_python(': ')
985        self.visit(node.body)
986
987    def visit_Ellipsis(self, node):
988        self.write_python('Ellipsis')
989
990    def generator_visit(left, right):
991        def visit(self, node):
992            self.write_python(left)
993            self.write_c(left)
994            self.visit(node.elt)
995            for comprehension in node.generators:
996                self.visit(comprehension)
997            self.write_c(right)
998#            self.write_python(right)
999        return visit
1000
1001    visit_ListComp = generator_visit('[', ']')
1002    visit_GeneratorExp = generator_visit('(', ')')
1003    visit_SetComp = generator_visit('{', '}')
1004    del generator_visit
1005
1006    def visit_DictComp(self, node):
1007        self.write_python('{')
1008        self.visit(node.key)
1009        self.write_python(': ')
1010        self.visit(node.value)
1011        for comprehension in node.generators:
1012            self.visit(comprehension)
1013        self.write_python('}')
1014
1015    def visit_IfExp(self, node):
1016        self.visit(node.body)
1017        self.write_c(' if ')
1018        self.visit(node.test)
1019        self.write_c(' else ')
1020        self.visit(node.orelse)
1021
1022    def visit_Starred(self, node):
1023        self.write_c('*')
1024        self.visit(node.value)
1025
1026    def visit_Repr(self, node):
1027        # XXX: python 2.6 only
1028        self.write_c('`')
1029        self.visit(node.value)
1030        self.write_python('`')
1031
1032    # Helper Nodes
1033
1034    def visit_alias(self, node):
1035        self.write_python(node.name)
1036        if node.asname is not None:
1037            self.write_python(' as ' + node.asname)
1038
1039    def visit_comprehension(self, node):
1040        self.write_c(' for ')
1041        self.visit(node.target)
1042        self.write_C(' in ')
1043#        self.write_python(' in ')
1044        self.visit(node.iter)
1045        if node.ifs:
1046            for if_ in node.ifs:
1047                self.write_python(' if ')
1048                self.visit(if_)
1049
1050#    def visit_excepthandler(self, node):
1051#        self.newline(node)
1052#        self.write_python('except')
1053#        if node.type is not None:
1054#            self.write_python(' ')
1055#            self.visit(node.type)
1056#            if node.name is not None:
1057#                self.write_python(' as ')
1058#                self.visit(node.name)
1059#        self.body(node.body)
1060
1061    def visit_arguments(self, node):
1062        self.signature(node)
1063
1064def Iq1(q, porod_scale, porod_exp, lorentz_scale, lorentz_length, peak_pos, lorentz_exp=17):
1065    z1 = z2 = z = abs(q - peak_pos) * lorentz_length
1066    if(q > p):
1067        q = p + 17
1068        p = q - 5
1069    z3 = -8
1070    inten = (porod_scale / q ** porod_exp
1071                + lorentz_scale /(1 + z ** lorentz_exp))
1072    return inten
1073
1074def Iq(q, porod_scale, porod_exp, lorentz_scale, lorentz_length, peak_pos, lorentz_exp=17):
1075    z1 = z2 = z = abs(q - peak_pos) * lorentz_length
1076    if(q > p):
1077        q = p + 17
1078        p = q - 5
1079    elif(q == p):
1080        q = p * q
1081        q *= z1
1082        p = z1
1083    elif(q == 17):
1084        q = p * q - 17
1085    else:
1086        q += 7
1087    z3 = -8
1088    inten = (porod_scale / q ** porod_exp
1089                + lorentz_scale /(1 + z ** lorentz_exp))
1090    return inten
1091
1092def print_function(f=None):
1093    """
1094    Print out the code for the function
1095    """
1096    # Include some comments to see if they get printed
1097    import ast
1098    import inspect
1099    if f is not None:
1100        tree = ast.parse(inspect.getsource(f))
1101        tree_source = to_source(tree)
1102        print(tree_source)
1103
1104def add_constants (sniplets, c_constants):
1105    sniplets.append("#include <math.h>")
1106    sniplets.append("")
1107    vars = c_constants.keys()
1108    for c_var in vars:
1109        c_values = c_constants[c_var]
1110        declare_values = str(c_values)
1111        str_dcl = "double " + c_var
1112        if (hasattr(c_values,'__len__')):
1113            str_dcl += "[]"
1114            len_prev = len(declare_values)
1115            len_after = len_prev - 1
1116            declare_values = declare_values.replace ('[','')
1117            declare_values = declare_values.replace (']','').strip()
1118            while (len_after < len_prev):
1119                len_prev = len_after
1120                declare_values = declare_values.replace ('  ',' ')
1121                len_after = len(declare_values)
1122            declare_values = "{" + declare_values.replace (' ',',') + "}"
1123        str_dcl += " = " + declare_values + ";"
1124        sniplets.append (str_dcl)
1125        sniplets.append("")
1126
1127def translate(functions, constants=0):
1128    sniplets = []
1129#    sniplets.append("#include <math.h>")
1130#    sniplets.append("static double pi = 3.14159265359;")
1131    add_constants (sniplets, constants)
1132    for source,fname,line_no in functions:
1133        line_directive = '#line %d "%s"' %(line_no,fname)
1134        line_directive = line_directive.replace('\\','\\\\')
1135#        sniplets.append(line_directive)
1136        tree = ast.parse(source)
1137        sniplet = to_source(tree, functions, constants) # in the future add filename, offset, constants
1138        sniplets.append(sniplet)
1139    c_code = "\n".join(sniplets)
1140    f_out = open ("xlate.c", "w+")
1141    f_out.write (c_code)
1142    f_out.close()
1143    return("\n".join(sniplets))
1144
1145def get_file_names():
1146    fname_in = ""
1147    fname_out = ""
1148    if(len(sys.argv) > 1):
1149        fname_in = sys.argv[1]
1150        fname_base = os.path.splitext(fname_in)
1151        if(len(sys.argv) == 2):
1152            fname_out = str(fname_base[0]) + '.c'
1153        else:
1154            fname_out = sys.argv[2]
1155        if(len(fname_in) > 0):
1156            python_file = open(sys.argv[1], "r")
1157            if(len(fname_out) > 0):
1158                file_out = open(fname_out, "w+")
1159    return len(sys.argv), fname_in, fname_out
1160
1161if __name__ == "__main__":
1162    import os
1163    print("Parsing...using Python" + sys.version)
1164    try:
1165        fname_in = ""
1166        fname_out = ""
1167        if(len(sys.argv) == 1):
1168            print("Usage:\npython parse01.py <infile> [<outfile>](if omitted, output file is '<infile>.c'")
1169        else:
1170            fname_in = sys.argv[1]
1171            fname_base = os.path.splitext(fname_in)
1172            if(len(sys.argv) == 2):
1173                fname_out = str(fname_base[0]) + '.c'
1174            else:
1175                fname_out = sys.argv[2]
1176            if(len(fname_in) > 0):
1177                python_file = open(sys.argv[1], "r")
1178                if(len(fname_out) > 0):
1179                    file_out = open(fname_out, "w+")
1180                functions = ["MultAsgn", "Iq41", "Iq2"]
1181                tpls = [functions, fname_in, 0]
1182                c_txt = translate(tpls)
1183                file_out.write(c_txt)
1184                file_out.close()
1185    except Exception as excp:
1186        print("Error:\n" + str(excp.args))
1187    print("...Done")
Note: See TracBrowser for help on using the repository browser.