source: sasmodels/sasmodels/py2c.py @ 71779b2

Last change on this file since 71779b2 was 71779b2, checked in by Omer Eisenberg <omereis@…>, 6 years ago

Supporting multiple assignment, SINCOS, power, array initiation

  • Property mode set to 100644
File size: 35.4 KB
Line 
1
2"""
3    codegen
4    ~~~~~~~
5
6    Extension to ast that allow ast -> python code generation.
7
8    :copyright: Copyright 2008 by Armin Ronacher.
9    :license: BSD.
10"""
11"""
12    Variables definition in C
13    -------------------------
14    Defining variables within the Translate function is a bit of a guess work, using following rules.
15    *   By default, a variable is a 'double'.
16    *   Variable in a for loop is an int.
17    *   Variable that is references with brackets is an array of doubles. The variable within the brackets
18        is integer. For example, in the reference 'var1[var2]', var1 is a double array, and var2 is an integer.
19    *   Assignment to an argument makes that argument an array, and the index in that assignment is 0.
20        For example, the following python code
21            def func (arg1, arg2):
22                arg2 = 17.
23        is translated to the following C code
24            double func (double arg1)
25            {
26                arg2[0] = 17.0;
27            }
28        For example, the following python code is translated to the following C code
29            def func (arg1, arg2):          double func (double arg1) {
30                arg2 = 17.                      arg2[0] = 17.0;
31                                            }
32    *   All functions are defined as double, even if there is no return statement.
33
34
35Update Notes
36============
3711/22 14:15, O.E.   Each 'visit_*' method is to build a C statement string. It shold insert 4 blanks per indentation level.
38                    The 'body' method will combine all the strings, by adding the 'current_statement' to the c_proc string list
39   11/2017, OE: variables, argument definition implemented.
40   Note: An argument is considered an array if it is the target of an assignment. In that case it is translated to <var>[0]
4111/27/2017, OE: 'pow' basicly working
42  /12/2017, OE: Multiple assignment: a1,a2,...,an=b1,b2,...bn implemented
43  /12/2017, OE: Power function, including special cases of square(x) (pow(x,2)) and cube(x) (pow(x,3)), implemented in translate_power,
44                called from visit_BinOp
4512/07/2017, OE: Translation of integer division, '\\' in python, implemented in translate_integer_divide, called from visit_BinOp
4612/07/2017, OE: C variable definition handled in 'define_C_Vars'
47"""
48import ast
49import sys
50from ast import NodeVisitor
51
52BINOP_SYMBOLS = {}
53BINOP_SYMBOLS[ast.Add] = '+'
54BINOP_SYMBOLS[ast.Sub] = '-'
55BINOP_SYMBOLS[ast.Mult] = '*'
56BINOP_SYMBOLS[ast.Div] = '/'
57BINOP_SYMBOLS[ast.Mod] = '%'
58BINOP_SYMBOLS[ast.Pow] = '**'
59BINOP_SYMBOLS[ast.LShift] = '<<'
60BINOP_SYMBOLS[ast.RShift] = '>>'
61BINOP_SYMBOLS[ast.BitOr] = '|'
62BINOP_SYMBOLS[ast.BitXor] = '^'
63BINOP_SYMBOLS[ast.BitAnd] = '&'
64BINOP_SYMBOLS[ast.FloorDiv] = '//'
65
66BOOLOP_SYMBOLS = {}
67BOOLOP_SYMBOLS[ast.And] = '&&'
68BOOLOP_SYMBOLS[ast.Or]  = '||'
69
70CMPOP_SYMBOLS = {}
71CMPOP_SYMBOLS[ast.Eq] = '=='
72CMPOP_SYMBOLS[ast.NotEq] = '!='
73CMPOP_SYMBOLS[ast.Lt] = '<'
74CMPOP_SYMBOLS[ast.LtE] = '<='
75CMPOP_SYMBOLS[ast.Gt] = '>'
76CMPOP_SYMBOLS[ast.GtE] = '>='
77CMPOP_SYMBOLS[ast.Is] = 'is'
78CMPOP_SYMBOLS[ast.IsNot] = 'is not'
79CMPOP_SYMBOLS[ast.In] = 'in'
80CMPOP_SYMBOLS[ast.NotIn] = 'not in'
81
82UNARYOP_SYMBOLS = {}
83UNARYOP_SYMBOLS[ast.Invert] = '~'
84UNARYOP_SYMBOLS[ast.Not] = 'not'
85UNARYOP_SYMBOLS[ast.UAdd] = '+'
86UNARYOP_SYMBOLS[ast.USub] = '-'
87
88
89#def to_source(node, indent_with=' ' * 4, add_line_information=False):
90def to_source(node, func_name):
91    """This function can convert a node tree back into python sourcecode.
92    This is useful for debugging purposes, especially if you're dealing with
93    custom asts not generated by python itself.
94
95    It could be that the sourcecode is evaluable when the AST itself is not
96    compilable / evaluable.  The reason for this is that the AST contains some
97    more data than regular sourcecode does, which is dropped during
98    conversion.
99
100    Each level of indentation is replaced with `indent_with`.  Per default this
101    parameter is equal to four spaces as suggested by PEP 8, but it might be
102    adjusted to match the application's styleguide.
103
104    If `add_line_information` is set to `True` comments for the line numbers
105    of the nodes are added to the output.  This can be used to spot wrong line
106    number information of statement nodes.
107    """
108    generator = SourceGenerator(' ' * 4, False)
109    generator.required_functions = func_name
110    generator.visit(node)
111
112#    return ''.join(generator.result)
113    return ''.join(generator.c_proc)
114
115class SourceGenerator(NodeVisitor):
116    """This visitor is able to transform a well formed syntax tree into python
117    sourcecode.  For more details have a look at the docstring of the
118    `node_to_source` function.
119    """
120
121    def __init__(self, indent_with, add_line_information=False):
122        self.result = []
123        self.indent_with = indent_with
124        self.add_line_information = add_line_information
125        self.indentation = 0
126        self.new_lines = 0
127        self.c_proc = []
128# for C
129        self.signature_line = 0
130        self.arguments = []
131        self.name = ""
132        self.warnings = []
133        self.Statements = []
134        self.current_statement = ""
135        self.strMethodSignature = ""
136        self.C_Vars = []
137        self.C_IntVars = []
138        self.MathIncludeed = False
139        self.C_Pointers = []
140        self.C_DclPointers = []
141        self.C_Functions = []
142        self.C_Vectors = []
143        self.SubRef = False
144        self.InSubscript = False
145        self.Tuples = []
146        self.required_functions = []
147        self.is_sequence = False
148        self.visited_args = False
149
150    def write_python(self, x):
151        if self.new_lines:
152            if self.result:
153                self.result.append('\n' * self.new_lines)
154            self.result.append(self.indent_with * self.indentation)
155            self.new_lines = 0
156        self.result.append(x)
157
158    def write_c(self, x):
159        self.current_statement += x
160       
161    def add_c_line(self, x):
162        string = ''
163        for i in range (self.indentation):
164            string += ("    ")
165        string += str(x)
166        self.c_proc.append (str(string + "\n"))
167        x = ''
168
169    def add_current_line (self):
170        if (len(self.current_statement) > 0):
171            self.add_c_line (self.current_statement)
172            self.current_statement = ''
173
174    def AddUniqueVar (self, new_var):
175        if ((new_var not in self.C_Vars)):
176            self.C_Vars.append (str (new_var))
177
178    def WriteSincos (self, node):
179        angle = str(node.args[0].id)
180        self.write_c (node.args[1].id + " = sin (" + angle + ");")
181        self.add_current_line()
182        self.write_c (node.args[2].id + " = cos (" + angle + ");")
183        self.add_current_line()
184        for arg in node.args:
185            self.AddUniqueVar (arg.id)
186
187    def newline(self, node=None, extra=0):
188        self.new_lines = max(self.new_lines, 1 + extra)
189        if node is not None and self.add_line_information:
190            self.write_c('# line: %s' % node.lineno)
191            self.new_lines = 1
192        if (len(self.current_statement)):
193            self.Statements.append (self.current_statement)
194            self.current_statement = ''
195
196    def body(self, statements):
197        if (len(self.current_statement)):
198            self.add_current_line ()
199        self.new_line = True
200        self.indentation += 1
201        for stmt in statements:
202            target_name = ''
203            if (hasattr (stmt, 'targets')):
204                if (hasattr(stmt.targets[0],'id')):
205                    target_name = stmt.targets[0].id # target name needed for debug only
206            self.visit(stmt)
207        self.add_current_line () # just for breaking point. to be deleted.
208        self.indentation -= 1
209
210    def body_or_else(self, node):
211        self.body(node.body)
212        if node.orelse:
213            self.newline()
214            self.write_c('else:')
215            self.body(node.orelse)
216
217    def signature(self, node):
218        want_comma = []
219        def write_comma():
220            if want_comma:
221                self.write_c(', ')
222            else:
223                want_comma.append(True)
224# for C
225        for arg in node.args:
226            self.arguments.append (arg.arg)
227
228        padding = [None] * (len(node.args) - len(node.defaults))
229        for arg, default in zip(node.args, padding + node.defaults):
230            if default is not None:
231                self.warnings.append ("Default Parameter unknown to C")
232                w_str = "Default Parameters are unknown to C: '" + arg.arg + " = " + str(default.n) + "'"
233                self.warnings.append (w_str)
234#                self.write_python('=')
235#                self.visit(default)
236
237    def decorators(self, node):
238        for decorator in node.decorator_list:
239            self.newline(decorator)
240            self.write_python('@')
241            self.visit(decorator)
242
243    # Statements
244
245    def visit_Assert(self, node):
246        self.newline(node)
247        self.write_c('assert ')
248        self.visit(node.test)
249        if node.msg is not None:
250           self.write_python(', ')
251           self.visit(node.msg)
252
253    def define_C_Vars (self, target):
254        if (hasattr (target, 'id')):
255# a variable is considered an array if it apears in the agrument list
256# and being assigned to. For example, the variable p in the following
257# sniplet is a pointer, while q is not
258# def somefunc (p, q):
259#  p = q + 1
260#  return
261#
262            if (target.id not in self.C_Vars):
263                if (target.id in self.arguments):
264                    idx = self.arguments.index (target.id)
265                    new_target = self.arguments[idx] + "[0]"
266                    if (new_target not in self.C_Pointers):
267                        target.id = new_target
268                        self.C_Pointers.append (self.arguments[idx])
269                else:
270                    self.C_Vars.append (target.id)
271
272    def visit_Assign(self, node):
273        for idx, target in enumerate(node.targets): # multi assign, as in 'a = b = c = 7'
274            if idx:
275                self.write_c(' = ')
276            self.define_C_Vars (target)
277            self.visit(target)
278        if (len(self.Tuples) > 0):
279            tplTargets = list (self.Tuples)
280            self.Tuples.clear()
281        self.write_c(' = ')
282        self.is_sequence = False
283        self.visited_args = False
284        self.visit(node.value)
285        self.write_c(';')
286        self.add_current_line ()
287        for n,item in enumerate (self.Tuples):
288            self.visit(tplTargets[n])
289            self.write_c(' = ')
290            self.visit(item)
291            self.write_c(';')
292            self.add_current_line ()
293        if ((self.is_sequence) and (not self.visited_args)):
294            for target  in node.targets:
295                if (hasattr (target, 'id')):
296                    if ((target.id in self.C_Vars) and (target.id not in self.C_DclPointers)):
297                        if (target.id not in self.C_DclPointers):
298                            self.C_DclPointers.append (target.id)
299                            if (target.id in self.C_Vars):
300                                self.C_Vars.remove(target.id)
301        self.current_statement = ''
302
303    def visit_AugAssign(self, node):
304        if (node.target.id not in self.C_Vars):
305            if (node.target.id not in self.arguments):
306                self.C_Vars.append (node.target.id)
307        self.visit(node.target)
308        self.write_c(' ' + BINOP_SYMBOLS[type(node.op)] + '= ')
309        self.visit(node.value)
310        self.write_c(';')
311        self.add_current_line ()
312
313    def visit_ImportFrom(self, node):
314        self.newline(node)
315        self.write_python('from %s%s import ' % ('.' * node.level, node.module))
316        for idx, item in enumerate(node.names):
317            if idx:
318                self.write_python(', ')
319            self.write_python(item)
320
321    def visit_Import(self, node):
322        self.newline(node)
323        for item in node.names:
324            self.write_python('import ')
325            self.visit(item)
326
327    def visit_Expr(self, node):
328        self.newline(node)
329        self.generic_visit(node)
330
331    def listToDeclare(self, Vars):
332        s = ''
333        if (len(Vars) > 0):
334            s = ",".join (Vars)
335        return (s)
336
337    def insert_C_Vars (self, start_var):
338        fLine = False
339        if (len(self.C_Vars) > 0):
340            s = self.listToDeclare(self.C_Vars)
341            self.c_proc.insert (start_var, "    double " + s + ";\n")
342            fLine = True
343            start_var += 1
344        if (len(self.C_IntVars) > 0):
345            s = self.listToDeclare(self.C_IntVars)
346            self.c_proc.insert (start_var, "    int " + s + ";\n")
347            fLine = True
348            start_var += 1
349        if (len (self.C_Vectors) > 0):
350            s = self.listToDeclare(self.C_Vectors)
351            for n in range(len(self.C_Vectors)):
352                name = "vec" + str(n+1)
353                c_dcl = "    double " + name + "[] = {" + self.C_Vectors[n] + "};"
354                self.c_proc.insert (start_var, c_dcl + "\n")
355                start_var += 1
356        if (len (self.C_Pointers) > 0):
357            vars = ""
358            for n in range(len(self.C_Pointers)):
359                if (len(vars) > 0):
360                    vars += ", "
361                if (self.C_Pointers[n] not in self.arguments):
362                    vars += "*" + self.C_Pointers[n]
363            if (len(vars) > 0):
364                c_dcl = "    double " + vars + ";"
365                self.c_proc.insert (start_var, c_dcl + "\n")
366                start_var += 1
367        if (len (self.C_DclPointers) > 0):
368            vars = ''
369            for n in range(len(self.C_DclPointers)):
370                if (len(vars) > 0):
371                    vars += ', '
372                vars += "*" + self.C_DclPointers[n]
373            if (len(vars) > 0):
374                c_dcl = "    double " + vars + ";"
375                self.c_proc.insert (start_var, c_dcl + "\n")
376                start_var += 1
377        self.C_Vars.clear()
378        self.C_IntVars.clear()
379        self.C_Vectors.clear()
380        self.C_Pointers.clear()
381        if (fLine == True):
382            self.c_proc.insert (start_var, "\n")
383        return
384        s = ''
385        for n in range(len(self.C_Vars)):
386            s += str(self.C_Vars[n])
387            if n < len(self.C_Vars) - 1:
388                s += ", "
389        if (len(s) > 0):
390            self.c_proc.insert (start_var, "    double " + s + ";\n")
391            self.c_proc.insert (start_var + 1, "\n")
392
393    def writeInclude (self):
394        if (self.MathIncludeed == False):
395            self.add_c_line("#include <math.h>\n")
396            self.add_c_line("static double pi = 3.14159265359;\n")
397            self.MathIncludeed = True
398
399    def ListToString (self, strings):
400        s = ''
401        for n in range(len(strings)):
402            s += strings[n]
403            if (n < (len(strings) - 1)):
404                s += ", "
405        return (s)
406
407    def getMethodSignature (self):
408#        args_str = ListToString (self.arguments)
409        args_str = ''
410        for n in range(len(self.arguments)):
411            args_str += "double " + self.arguments[n]
412            if (n < (len(self.arguments) - 1)):
413                args_str += ", "
414        return (args_str)
415#        self.strMethodSignature = 'double ' + self.name + ' (' + args_str + ")"
416
417    def InsertSignature (self):
418        args_str = ''
419        for n in range(len(self.arguments)):
420            args_str += "double " + self.arguments[n]
421            if (self.arguments[n] in self.C_Pointers):
422                args_str += "[]"
423            if (n < (len(self.arguments) - 1)):
424                args_str += ", "
425        self.strMethodSignature = 'double ' + self.name + ' (' + args_str + ")"
426        if (self.signature_line > 0):
427            self.c_proc.insert (self.signature_line, self.strMethodSignature)
428
429    def visit_FunctionDef(self, node):
430        self.newline(extra=1)
431        self.decorators(node)
432        self.newline(node)
433        self.arguments = []
434        self.name = node.name
435        if self.name not in self.required_functions[0]:
436            return
437        print("Parsing '" + self.name + "'")
438        args_str = ""
439
440        self.visit(node.args)
441# for C
442        self.writeInclude()
443        self.getMethodSignature ()
444# for C
445        self.signature_line = len(self.c_proc)
446#        self.add_c_line(self.strMethodSignature)
447        self.add_c_line("\n{")
448        start_vars = len(self.c_proc) + 1
449        self.body(node.body)
450        self.add_c_line("}\n")
451        self.InsertSignature ()
452        self.insert_C_Vars (start_vars)
453        self.C_Pointers = []
454
455    def visit_ClassDef(self, node):
456        have_args = []
457        def paren_or_comma():
458            if have_args:
459                self.write_python(', ')
460            else:
461                have_args.append(True)
462                self.write_python('(')
463
464        self.newline(extra=2)
465        self.decorators(node)
466        self.newline(node)
467        self.write_python('class %s' % node.name)
468        for base in node.bases:
469            paren_or_comma()
470            self.visit(base)
471        # XXX: the if here is used to keep this module compatible
472        #      with python 2.6.
473        if hasattr(node, 'keywords'):
474            for keyword in node.keywords:
475                paren_or_comma()
476                self.write_python(keyword.arg + '=')
477                self.visit(keyword.value)
478            if node.starargs is not None:
479                paren_or_comma()
480                self.write_python('*')
481                self.visit(node.starargs)
482            if node.kwargs is not None:
483                paren_or_comma()
484                self.write_python('**')
485                self.visit(node.kwargs)
486        self.write_python(have_args and '):' or ':')
487        self.body(node.body)
488
489    def visit_If(self, node):
490        self.write_c('if ')
491        self.visit(node.test)
492        self.write_c(' {')
493        self.body(node.body)
494        self.add_c_line('}')
495        while True:
496            else_ = node.orelse
497            if len(else_) == 0:
498                break
499#            elif hasattr (else_, 'orelse'):
500            elif len(else_) == 1 and isinstance(else_[0], ast.If):
501                node = else_[0]
502#                self.newline()
503                self.write_c('else if ')
504                self.visit(node.test)
505                self.write_c(' {')
506                self.body(node.body)
507                self.add_current_line ()
508                self.add_c_line('}')
509#                break
510            else:
511                self.newline()
512                self.write_c('else {')
513                self.body(node.body)
514                self.add_c_line('}')
515                break
516
517    def getNodeLineNo (self, node):
518        line_number = -1
519        if (hasattr (node,'value')):
520            line_number = node.value.lineno
521        elif hasattr(node,'iter'):
522            if hasattr(node.iter,'lineno'):
523                line_number = node.iter.lineno
524        return (line_number)
525
526    def visit_For(self, node):
527        if (len(self.current_statement) > 0):
528            self.add_c_line(self.current_statement)
529            self.current_statement = ''
530        fForDone = False
531        if (hasattr(node.iter,'func')):
532            if (hasattr (node.iter.func,'id')):
533                if (node.iter.func.id == 'range'):
534                    if ('n' not in self.C_IntVars):
535                        self.C_IntVars.append ('n')
536                    self.write_c ("for (n=0 ; n < len(")
537                    for arg in node.iter.args:
538                        self.visit(arg)
539                    self.write_c (") ; n++) {")
540                    self.body_or_else(node)
541                    self.write_c ("}")
542                    fForDone = True
543        if (fForDone == False):
544            line_number = self.getNodeLineNo (node)
545            self.current_statement = ''
546            self.write_c('for ')
547            self.visit(node.target)
548            self.write_c(' in ')
549            self.visit(node.iter)
550            self.write_c(':')
551            errStr = "Conversion Error in function " + self.name + ", Line #" + str (line_number)
552            errStr += "\nPython for expression not supported: '" + self.current_statement + "'"
553            raise Exception(errStr)
554
555    def visit_While(self, node):
556        self.newline(node)
557        self.write_c('while ')
558        self.visit(node.test)
559        self.write_c(':')
560        self.body_or_else(node)
561
562    def visit_With(self, node):
563        self.newline(node)
564        self.write_python('with ')
565        self.visit(node.context_expr)
566        if node.optional_vars is not None:
567            self.write_python(' as ')
568            self.visit(node.optional_vars)
569        self.write_python(':')
570        self.body(node.body)
571
572    def visit_Pass(self, node):
573        self.newline(node)
574        self.write_python('pass')
575
576    def visit_Print(self, node):
577# XXX: python 2.6 only
578        self.newline(node)
579        self.write_c('print ')
580        want_comma = False
581        if node.dest is not None:
582            self.write_c(' >> ')
583            self.visit(node.dest)
584            want_comma = True
585        for value in node.values:
586            if want_comma:
587                self.write_c(', ')
588            self.visit(value)
589            want_comma = True
590        if not node.nl:
591            self.write_c(',')
592
593    def visit_Delete(self, node):
594        self.newline(node)
595        self.write_python('del ')
596        for idx, target in enumerate(node):
597            if idx:
598                self.write_python(', ')
599            self.visit(target)
600
601    def visit_TryExcept(self, node):
602        self.newline(node)
603        self.write_python('try:')
604        self.body(node.body)
605        for handler in node.handlers:
606            self.visit(handler)
607
608    def visit_TryFinally(self, node):
609        self.newline(node)
610        self.write_python('try:')
611        self.body(node.body)
612        self.newline(node)
613        self.write_python('finally:')
614        self.body(node.finalbody)
615
616    def visit_Global(self, node):
617        self.newline(node)
618        self.write_python('global ' + ', '.join(node.names))
619
620    def visit_Nonlocal(self, node):
621        self.newline(node)
622        self.write_python('nonlocal ' + ', '.join(node.names))
623
624    def visit_Return(self, node):
625        self.newline(node)
626        if node.value is None:
627            self.write_c('return')
628        else:
629            self.write_c('return (')
630            self.visit(node.value)
631        self.write_c(');')
632#      self.add_current_statement (self)
633        self.add_c_line (self.current_statement)
634        self.current_statement = ''
635
636    def visit_Break(self, node):
637        self.newline(node)
638        self.write_c('break')
639
640    def visit_Continue(self, node):
641        self.newline(node)
642        self.write_c('continue')
643
644    def visit_Raise(self, node):
645        # XXX: Python 2.6 / 3.0 compatibility
646        self.newline(node)
647        self.write_python('raise')
648        if hasattr(node, 'exc') and node.exc is not None:
649            self.write_python(' ')
650            self.visit(node.exc)
651            if node.cause is not None:
652                self.write_python(' from ')
653                self.visit(node.cause)
654        elif hasattr(node, 'type') and node.type is not None:
655            self.visit(node.type)
656            if node.inst is not None:
657                self.write_python(', ')
658                self.visit(node.inst)
659            if node.tback is not None:
660                self.write_python(', ')
661                self.visit(node.tback)
662
663    # Expressions
664
665    def visit_Attribute(self, node):
666        errStr = "Conversion Error in function " + self.name + ", Line #" + str (node.value.lineno)
667        errStr += "\nPython expression not supported: '" + node.value.id + "." + node.attr + "'"
668        raise Exception(errStr)
669        self.visit(node.value)
670        self.write_python('.' + node.attr)
671
672    def visit_Call(self, node):
673        want_comma = []
674        def write_comma():
675            if want_comma:
676                self.write_c(', ')
677            else:
678                want_comma.append(True)
679        if (hasattr (node.func, 'id')):
680            if (node.func.id not in self.C_Functions):
681                self.C_Functions.append (node.func.id)
682            if (node.func.id == 'abs'):
683                self.write_c ("fabs ")
684            elif (node.func.id == 'int'):
685                self.write_c('(int) ')
686            elif (node.func.id == "SINCOS"):
687                self.WriteSincos (node)
688                return
689            else:
690                self.visit(node.func)
691        else:
692            self.visit(node.func)
693#self.C_Functions
694        self.write_c('(')
695        for arg in node.args:
696            write_comma()
697            self.visited_args = True 
698            self.visit(arg)
699        for keyword in node.keywords:
700            write_comma()
701            self.write_c(keyword.arg + '=')
702            self.visit(keyword.value)
703        if hasattr (node, 'starargs'):
704            if node.starargs is not None:
705                write_comma()
706                self.write_c('*')
707                self.visit(node.starargs)
708        if hasattr (node, 'kwargs'):
709            if node.kwargs is not None:
710                write_comma()
711                self.write_c('**')
712                self.visit(node.kwargs)
713        self.write_c(')')
714
715    def visit_Name(self, node):
716        self.write_c(node.id)
717        if ((node.id in self.C_Pointers) and (not self.SubRef)):
718            self.write_c("[0]")
719        name = ""
720        sub = node.id.find("[")
721        if (sub > 0):
722            name = node.id[0:sub].strip()
723        else:
724            name = node.id
725#       add variable to C_Vars if it ins't there yet, not an argument and not a number
726        if ((name not in self.C_Functions) and (name not in self.C_Vars) and (name not in self.C_IntVars) and (name not in self.arguments) and (name.isnumeric () == False)):
727            if (self.InSubscript):
728                self.C_IntVars.append (node.id)
729            else:
730                self.C_Vars.append (node.id)
731
732    def visit_Str(self, node):
733        self.write_c(repr(node.s))
734
735    def visit_Bytes(self, node):
736        self.write_c(repr(node.s))
737
738    def visit_Num(self, node):
739        self.write_c(repr(node.n))
740
741    def visit_Tuple(self, node):
742        for idx, item in enumerate(node.elts):
743            if idx:
744                self.Tuples.append(item)
745            else:
746                self.visit(item)
747
748    def sequence_visit(left, right):
749        def visit(self, node):
750            self.is_sequence = True
751            s = ""
752            for idx, item in enumerate(node.elts):
753                if ((idx > 0) and (len(s) > 0)):
754                    s += ', '
755                if (hasattr (item, 'id')):
756                    s += item.id
757                elif (hasattr(item,'n')):
758                    s += str(item.n)
759            if (len(s) > 0):
760                self.C_Vectors.append (s)
761                vec_name = "vec"  + str(len(self.C_Vectors))
762                self.write_c(vec_name)
763                vec_name += "#"
764        return visit
765
766    visit_List = sequence_visit('[', ']')
767    visit_Set = sequence_visit('{', '}')
768    del sequence_visit
769
770    def visit_Dict(self, node):
771        self.write_python('{')
772        for idx, (key, value) in enumerate(zip(node.keys, node.values)):
773            if idx:
774                self.write_python(', ')
775            self.visit(key)
776            self.write_python(': ')
777            self.visit(value)
778        self.write_python('}')
779
780    def translate_power (self, node):
781# get exponent by visiting the right hand argument.
782        function_name = "pow"
783        temp_statement = self.current_statement
784        self.current_statement = ''
785        self.visit(node.right)
786        exponent = self.current_statement.replace(' ','')
787        exponent = exponent.replace('(','')
788        exponent = exponent.replace(')','')
789        self.current_statement = temp_statement
790# is the right hand argument, the exponent, a number?
791        if (hasattr(node.right, 'n')): # power of constand
792            exponent = node.right.n
793            if (exponent == 2):
794                function_name = "square"
795            elif (exponent == 3):
796                function_name = "cube"
797            elif (exponent == 0.5):
798                function_name = "sqrt"
799        elif (exponent == "1/2"):
800                function_name = "sqrt"
801        self.write_c (function_name + " (")
802        self.visit(node.left)
803        if (function_name == "pow"):
804            self.write_c(", ")
805            self.visit(node.right)
806        self.write_c(")")
807
808    def translate_integer_divide (self, node):
809        self.write_c ("(int) (")
810        self.visit(node.left)
811        self.write_c (") / (int) (")
812        self.visit(node.right)
813        self.write_c (")")
814
815    def visit_BinOp(self, node):
816        if ('%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.Pow]):
817            self.translate_power (node)
818        elif ('%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.FloorDiv]):
819            self.translate_integer_divide (node)
820        else:
821            self.visit(node.left)
822            self.write_c(' %s ' % BINOP_SYMBOLS[type(node.op)])
823            self.visit(node.right)
824#       for C
825    def visit_BoolOp(self, node):
826        self.write_c('(')
827        for idx, value in enumerate(node.values):
828            if idx:
829                self.write_c(' %s ' % BOOLOP_SYMBOLS[type(node.op)])
830            self.visit(value)
831        self.write_c(')')
832
833    def visit_Compare(self, node):
834        self.write_c('(')
835        self.visit(node.left)
836        for op, right in zip(node.ops, node.comparators):
837            self.write_c(' %s ' % CMPOP_SYMBOLS[type(op)])
838            self.visit(right)
839        self.write_c(')')
840
841    def visit_UnaryOp(self, node):
842        self.write_c('(')
843        op = UNARYOP_SYMBOLS[type(node.op)]
844        self.write_c(op)
845        if op == 'not':
846            self.write_c(' ')
847        self.visit(node.operand)
848        self.write_c(')')
849
850    def visit_Subscript(self, node):
851        if (node.value.id not in self.C_Pointers):
852            self.C_Pointers.append (node.value.id)
853        self.SubRef = True
854        self.visit(node.value)
855        self.SubRef = False
856        self.write_c('[')
857        self.InSubscript = True
858        self.visit(node.slice)
859        self.InSubscript = False
860        self.write_c(']')
861
862    def visit_Slice(self, node):
863        if node.lower is not None:
864            self.visit(node.lower)
865        self.write_python(':')
866        if node.upper is not None:
867            self.visit(node.upper)
868        if node.step is not None:
869            self.write_python(':')
870            if not (isinstance(node.step, Name) and node.step.id == 'None'):
871                self.visit(node.step)
872
873    def visit_ExtSlice(self, node):
874        for idx, item in node.dims:
875            if idx:
876                self.write_python(', ')
877            self.visit(item)
878
879    def visit_Yield(self, node):
880        self.write_python('yield ')
881        self.visit(node.value)
882
883    def visit_Lambda(self, node):
884        self.write_python('lambda ')
885        self.visit(node.args)
886        self.write_python(': ')
887        self.visit(node.body)
888
889    def visit_Ellipsis(self, node):
890        self.write_python('Ellipsis')
891
892    def generator_visit(left, right):
893        def visit(self, node):
894            self.write_python(left)
895            self.visit(node.elt)
896            for comprehension in node.generators:
897                self.visit(comprehension)
898            self.write_python(right)
899        return visit
900
901    visit_ListComp = generator_visit('[', ']')
902    visit_GeneratorExp = generator_visit('(', ')')
903    visit_SetComp = generator_visit('{', '}')
904    del generator_visit
905
906    def visit_DictComp(self, node):
907        self.write_python('{')
908        self.visit(node.key)
909        self.write_python(': ')
910        self.visit(node.value)
911        for comprehension in node.generators:
912            self.visit(comprehension)
913        self.write_python('}')
914
915    def visit_IfExp(self, node):
916        self.visit(node.body)
917        self.write_c(' if ')
918        self.visit(node.test)
919        self.write_c(' else ')
920        self.visit(node.orelse)
921
922    def visit_Starred(self, node):
923        self.write_c('*')
924        self.visit(node.value)
925
926    def visit_Repr(self, node):
927        # XXX: python 2.6 only
928        self.write_c('`')
929        self.visit(node.value)
930        self.write_python('`')
931
932    # Helper Nodes
933
934    def visit_alias(self, node):
935        self.write_python(node.name)
936        if node.asname is not None:
937            self.write_python(' as ' + node.asname)
938
939    def visit_comprehension(self, node):
940        self.write_c(' for ')
941        self.visit(node.target)
942        self.write_python(' in ')
943        self.visit(node.iter)
944        if node.ifs:
945            for if_ in node.ifs:
946                self.write_python(' if ')
947                self.visit(if_)
948
949#    def visit_excepthandler(self, node):
950#        self.newline(node)
951#        self.write_python('except')
952#        if node.type is not None:
953#            self.write_python(' ')
954#            self.visit(node.type)
955#            if node.name is not None:
956#                self.write_python(' as ')
957#                self.visit(node.name)
958#        self.body(node.body)
959
960    def visit_arguments(self, node):
961        self.signature(node)
962
963def Iq1(q, porod_scale, porod_exp, lorentz_scale, lorentz_length, peak_pos, lorentz_exp=17):
964    z1 = z2 = z = abs (q - peak_pos) * lorentz_length
965    if (q > p):
966        q = p + 17
967        p = q - 5
968    z3 = -8
969    inten = (porod_scale / q ** porod_exp
970                + lorentz_scale / (1 + z ** lorentz_exp))
971    return inten
972
973def Iq(q, porod_scale, porod_exp, lorentz_scale, lorentz_length, peak_pos, lorentz_exp=17):
974    z1 = z2 = z = abs (q - peak_pos) * lorentz_length
975    if (q > p):
976        q = p + 17
977        p = q - 5
978    elif (q == p):
979        q = p * q
980        q *= z1
981        p = z1
982    elif (q == 17):
983        q = p * q - 17
984    else:
985        q += 7
986    z3 = -8
987    inten = (porod_scale / q ** porod_exp
988                + lorentz_scale / (1 + z ** lorentz_exp))
989    return inten
990
991def print_function(f=None):
992    """
993    Print out the code for the function
994    """
995    # Include some comments to see if they get printed
996    import ast
997    import inspect
998    if f is not None:
999        tree = ast.parse(inspect.getsource(f))
1000        tree_source = to_source (tree)
1001        print(tree_source)
1002
1003def translate (functions, constants=0):
1004    sniplets = []
1005    fname = functions[1]
1006    python_file = open (fname, "r")
1007    source = python_file.read()
1008    python_file.close()
1009    tree = ast.parse (source)
1010    sniplet = to_source (tree, functions) # in the future add filename, offset, constants
1011    sniplets.append(sniplet)
1012    return ("\n".join(sniplets))
1013
1014def get_file_names ():
1015    fname_in = ""
1016    fname_out = ""
1017    if (len(sys.argv) > 1):
1018        fname_in = sys.argv[1]
1019        fname_base = os.path.splitext (fname_in)
1020        if (len (sys.argv) == 2):
1021            fname_out = str (fname_base[0]) + '.c'
1022        else:
1023            fname_out = sys.argv[2]
1024        if (len (fname_in) > 0):
1025            python_file = open (sys.argv[1], "r")
1026            if (len (fname_out) > 0):
1027                file_out = open (fname_out, "w+")
1028    return len(sys.argv), fname_in, fname_out
1029
1030if __name__ == "__main__":
1031    import os
1032    print("Parsing...using Python" + sys.version)
1033    try:
1034        fname_in = ""
1035        fname_out = ""
1036        if (len (sys.argv) == 1):
1037            print ("Usage:\npython parse01.py <infile> [<outfile>] (if omitted, output file is '<infile>.c'")
1038        else:
1039            fname_in = sys.argv[1]
1040            fname_base = os.path.splitext (fname_in)
1041            if (len (sys.argv) == 2):
1042                fname_out = str (fname_base[0]) + '.c'
1043            else:
1044                fname_out = sys.argv[2]
1045            if (len (fname_in) > 0):
1046                python_file = open (sys.argv[1], "r")
1047                if (len (fname_out) > 0):
1048                    file_out = open (fname_out, "w+")
1049                functions = ["MultAsgn", "Iq41", "Iq2"]
1050                tpls = [functions, fname_in, 0]
1051                c_txt = translate (tpls)
1052                file_out.write (c_txt)
1053                file_out.close()
1054    except Exception as excp:
1055        print ("Error:\n" + str(excp.args))
1056    print("...Done")
Note: See TracBrowser for help on using the repository browser.