Changeset 937afef in sasmodels


Ignore:
Timestamp:
Dec 14, 2017 11:47:32 AM (7 years ago)
Author:
Omer Eisenberg <omereis@…>
Children:
fb5c8c7
Parents:
2badeca
Message:

fixed bug in 'for' loop translation

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/py2c.py

    r2badeca r937afef  
    454512/07/2017, OE: Translation of integer division, '\\' in python, implemented in translate_integer_divide, called from visit_BinOp 
    464612/07/2017, OE: C variable definition handled in 'define_C_Vars' 
     47              : Python integer division, '//', translated to C in 'translate_integer_divide' 
    4748""" 
    4849import ast 
    4950import sys 
    50 import os 
    5151from ast import NodeVisitor 
    5252 
     
    8989 
    9090#def to_source(node, indent_with=' ' * 4, add_line_information=False): 
    91 def to_source(node, func_name, global_vectors): 
     91def to_source(node, func_name): 
    9292    """This function can convert a node tree back into python sourcecode. 
    9393    This is useful for debugging purposes, especially if you're dealing with 
     
    106106    of the nodes are added to the output.  This can be used to spot wrong line 
    107107    number information of statement nodes. 
    108  
    109     global_vectors arguments is a list of strings, holding the names of global 
    110     variables already declared as arrays 
    111108    """ 
    112 #    print(str(node)) 
    113 #    return 
    114109    generator = SourceGenerator(' ' * 4, False) 
    115     generator.global_vectors = global_vectors 
    116110#    generator.required_functions = func_name 
    117111    generator.visit(node) 
     
    119113#    return ''.join(generator.result) 
    120114    return ''.join(generator.c_proc) 
    121  
    122 EXTREN_DOUBLE = "extern double " 
    123  
    124 def get_c_array_value(val): 
    125     val_str = str(val).strip() 
    126     val_str = val_str.replace('[','') 
    127     val_str = val_str.replace(']','') 
    128     val_str = val_str.strip() 
    129     double_blank = "  " in val_str 
    130     while (double_blank): 
    131         val_str = val_str.replace('  ',' ') 
    132         double_blank = "  " in val_str 
    133     val_str = val_str.replace(' ',',') 
    134     return (val_str) 
    135  
    136 def write_include_const (snippets, constants): 
    137     c_globals=[] 
    138     snippets.append("#include <math.h>\n") 
    139     const_names = constants.keys() 
    140     for name in const_names: 
    141         c_globals.append(name) 
    142         val = constants[name] 
    143         var_dcl = "double "# + " = " 
    144         if (hasattr(val, "__len__")): # the value is an array 
    145             val_str = get_c_array_value(val) 
    146             var_dcl += str(name) + "[] = {" + str(val_str) + "}" 
    147         else: 
    148             var_dcl += str(name) + " = " + str(val) 
    149         var_dcl += ";\n" 
    150         snippets.append(var_dcl) 
    151              
    152     return (c_globals) 
    153 #    for c_var in constants: 
    154 #        sTmp = EXTREN_DOUBLE + str(c_var) + ";" 
    155 #        c_externs.append (EXTREN_DOUBLE + c_var + ";") 
    156 #    snippets.append (c_externs) 
    157115 
    158116def isevaluable(s): 
     
    197155        self.is_sequence = False 
    198156        self.visited_args = False 
    199         self.global_vectors = [] 
    200157 
    201158    def write_python(self, x): 
     
    321278                    self.C_Vars.append (target.id) 
    322279 
     280    def add_semi_colon (self): 
     281        semi_pos = self.current_statement.find(';') 
     282        if (semi_pos < 0): 
     283            self.write_c(';') 
     284 
    323285    def visit_Assign(self, node): 
     286        self.add_current_line () 
    324287        for idx, target in enumerate(node.targets): # multi assign, as in 'a = b = c = 7' 
    325288            if idx: 
     
    334297        self.visited_args = False 
    335298        self.visit(node.value) 
    336         self.write_c(';') 
     299        self.add_semi_colon () 
     300#        self.write_c(';') 
    337301        self.add_current_line () 
    338302        for n,item in enumerate (self.Tuples): 
     
    340304            self.write_c(' = ') 
    341305            self.visit(item) 
    342             self.write_c(';') 
     306            self.add_semi_colon () 
    343307            self.add_current_line () 
    344308        if ((self.is_sequence) and (not self.visited_args)): 
    345             for target  in node.targets: 
     309            for target in node.targets: 
    346310                if (hasattr (target, 'id')): 
    347311                    if ((target.id in self.C_Vars) and (target.id not in self.C_DclPointers)): 
     
    359323        self.write_c(' ' + BINOP_SYMBOLS[type(node.op)] + '= ') 
    360324        self.visit(node.value) 
    361         self.write_c(';') 
     325        self.add_semi_colon () 
     326#        self.write_c(';') 
    362327        self.add_current_line () 
    363328 
     
    405370    def insert_C_Vars (self, start_var): 
    406371        fLine = False 
     372        start_var = self.write_C_Pointers (start_var) 
     373        if (len(self.C_IntVars) > 0): 
     374            for var in self.C_IntVars: 
     375                if (var in self.C_Vars): 
     376                    self.C_Vars.remove(var) 
     377            s = self.listToDeclare(self.C_IntVars) 
     378            self.c_proc.insert (start_var, "    int " + s + ";\n") 
     379            fLine = True 
     380            start_var += 1 
     381             
    407382        if (len(self.C_Vars) > 0): 
    408383            s = self.listToDeclare(self.C_Vars) 
     
    410385            fLine = True 
    411386            start_var += 1 
    412         if (len(self.C_IntVars) > 0): 
    413             s = self.listToDeclare(self.C_IntVars) 
    414             self.c_proc.insert (start_var, "    int " + s + ";\n") 
    415             fLine = True 
    416             start_var += 1 
     387#        if (len(self.C_IntVars) > 0): 
     388#            s = self.listToDeclare(self.C_IntVars) 
     389#            self.c_proc.insert (start_var, "    int " + s + ";\n") 
     390#            fLine = True 
     391#            start_var += 1 
    417392        if (len (self.C_Vectors) > 0): 
    418393            s = self.listToDeclare(self.C_Vectors) 
     
    422397                self.c_proc.insert (start_var, c_dcl + "\n") 
    423398                start_var += 1 
    424         if (len (self.C_Pointers) > 0): 
    425             vars = "" 
    426             for n in range(len(self.C_Pointers)): 
    427                 if (len(vars) > 0): 
    428                     vars += ", " 
    429                 if (self.C_Pointers[n] not in self.arguments): 
    430                     vars += "*" + self.C_Pointers[n] 
    431             if (len(vars) > 0): 
    432                 c_dcl = "    double " + vars + ";" 
    433                 self.c_proc.insert (start_var, c_dcl + "\n") 
    434                 start_var += 1 
    435399        self.C_Vars.clear() 
    436400        self.C_IntVars.clear() 
    437401        self.C_Vectors.clear() 
    438402        self.C_Pointers.clear() 
    439         self.C_DclPointers.clear() 
     403        self.C_DclPointers 
    440404        if (fLine == True): 
    441405            self.c_proc.insert (start_var, "\n") 
     
    450414            self.c_proc.insert (start_var + 1, "\n") 
    451415 
    452 #    def writeInclude (self): 
    453 #        if (self.MathIncludeed == False): 
    454 #            self.add_c_line("#include <math.h>\n") 
    455 #            self.add_c_line("static double pi = 3.14159265359;\n") 
    456 #            self.MathIncludeed = True 
     416    def writeInclude (self): 
     417        if (self.MathIncludeed == False): 
     418            self.add_c_line("#include <math.h>\n") 
     419            self.add_c_line("static double pi = 3.14159265359;\n") 
     420            self.MathIncludeed = True 
    457421 
    458422    def ListToString (self, strings): 
     
    493457        self.name = node.name 
    494458#        if self.name not in self.required_functions[0]: 
    495 #            return 
     459#           return 
    496460        print("Parsing '" + self.name + "'") 
    497461        args_str = "" 
     
    499463        self.visit(node.args) 
    500464# for C 
    501 #        self.writeInclude() 
     465        self.writeInclude() 
    502466        self.getMethodSignature () 
    503467# for C 
     
    583547        return (line_number) 
    584548 
     549    def GetNodeAsString (self, node): 
     550        res = '' 
     551        if (hasattr(node,'n')): 
     552            res = str(node.n) 
     553        elif (hasattr(node,'id')): 
     554            res = node.id 
     555        return (res) 
     556 
     557    def GetForRange(self, node): 
     558        stop = "" 
     559        start = '0' 
     560        step = '1' 
     561        for_args = [] 
     562        temp_statement = self.current_statement 
     563        self.current_statement = '' 
     564        for arg in node.iter.args: 
     565            self.visit(arg) 
     566            for_args.append(self.current_statement) 
     567            self.current_statement = '' 
     568        self.current_statement = temp_statement 
     569        if (len(for_args) == 1): 
     570            stop = for_args[0] 
     571        elif (len(for_args) == 2): 
     572            start = for_args[0] 
     573            stop = for_args[1] 
     574        elif (len(for_args) == 3): 
     575            start = for_args[0] 
     576            stop = for_args[1] 
     577            start = for_args[2] 
     578        else: 
     579            raise("Ilegal for loop parameters") 
     580        return (start, stop, step) 
     581 
    585582    def visit_For(self, node): 
    586         if (len(self.current_statement) > 0): 
    587             self.add_c_line(self.current_statement) 
    588             self.current_statement = '' 
     583# node: for iterator is stored in node.target. 
     584# Iterator name is in node.target.id. 
     585        self.add_current_line() 
     586#        if (len(self.current_statement) > 0): 
     587#            self.add_c_line(self.current_statement) 
     588#            self.current_statement = '' 
    589589        fForDone = False 
     590        self.current_statement = '' 
    590591        if (hasattr(node.iter,'func')): 
    591592            if (hasattr (node.iter.func,'id')): 
    592593                if (node.iter.func.id == 'range'): 
    593                     if ('n' not in self.C_IntVars): 
    594                         self.C_IntVars.append ('n') 
    595                     self.write_c ("for (n=0 ; n < len(") 
    596                     for arg in node.iter.args: 
    597                         self.visit(arg) 
    598                     self.write_c (") ; n++) {") 
     594                    self.visit(node.target) 
     595                    iterator = self.current_statement 
     596                    self.current_statement = '' 
     597                    if (iterator not in self.C_IntVars): 
     598                        self.C_IntVars.append (iterator) 
     599                    start, stop, step = self.GetForRange(node) 
     600                    self.write_c ("for (" + iterator + "=" + str(start) + \ 
     601                                  " ; " + iterator + " < " + str(stop) + \ 
     602                                  " ; " + iterator + " += " + str(step) + ") {") 
    599603                    self.body_or_else(node) 
    600604                    self.write_c ("}") 
    601                     self.add_current_line () # just for breaking point. to be deleted.  
    602605                    fForDone = True 
    603606        if (fForDone == False): 
     
    771774                self.write_c('**') 
    772775                self.visit(node.kwargs) 
    773         self.write_c(')') 
     776        self.write_c(');') 
    774777 
    775778    def visit_Name(self, node): 
     
    784787            name = node.id 
    785788#       add variable to C_Vars if it ins't there yet, not an argument and not a number 
    786         if (name not in self.global_vectors): 
    787             if ((name not in self.C_Functions) and (name not in self.C_Vars) and (name not in self.C_IntVars) and (name not in self.arguments) and (name.isnumeric () == False)): 
    788                 if (self.InSubscript): 
    789                     self.C_IntVars.append (node.id) 
    790                 else: 
    791                     self.C_Vars.append (node.id) 
     789        if ((name not in self.C_Functions) and (name not in self.C_Vars) and (name not in self.C_IntVars) and (name not in self.arguments) and (name.isnumeric () == False)): 
     790            if (self.InSubscript): 
     791                self.C_IntVars.append (node.id) 
     792            else: 
     793                self.C_Vars.append (node.id) 
    792794 
    793795    def visit_Str(self, node): 
     
    969971        def visit(self, node): 
    970972            self.write_python(left) 
     973            self.write_c(left) 
    971974            self.visit(node.elt) 
    972975            for comprehension in node.generators: 
    973976                self.visit(comprehension) 
    974             self.write_python(right) 
     977            self.write_c(right) 
     978#            self.write_python(right) 
    975979        return visit 
    976980 
     
    10161020        self.write_c(' for ') 
    10171021        self.visit(node.target) 
    1018         self.write_python(' in ') 
     1022        self.write_C(' in ') 
     1023#        self.write_python(' in ') 
    10191024        self.visit(node.iter) 
    10201025        if node.ifs: 
     
    10771082        print(tree_source) 
    10781083 
    1079  
    1080 def write_consts (constants): 
    1081     f = False 
    1082     if (constants is not None): 
    1083         f = True 
    1084  
    1085 def translate(functions, constants): 
    1086     # type: (List[Tuple[str, str, int]], Dict[str, Any]) -> str 
    1087 #    print ("Size of functions is: " + str(len(functions))) 
    1088     if (os.path.exists ("xlate.c")): 
    1089         os.remove("xlate.c") 
    1090     snippets = [] 
    1091     global_vectors = write_include_const (snippets, constants) 
    1092     const_keys = constants.keys () 
    1093     for source, filename, offset in functions: 
    1094         tree = ast.parse(source) 
    1095         snippet = to_source(tree, filename, global_vectors)#, offset) 
    1096         snippets.append(snippet) 
    1097     if (constants is not None): 
    1098         write_consts (constants) 
    1099     c_text = "\n".join(snippets) 
    1100     try: 
    1101         translated = open ("xlate.c", "a+") 
    1102         translated.write (c_text) 
    1103         translated.close() 
    1104     except Exception as excp: 
    1105         strErr = "Error:\n" + str(excp.args) 
    1106         print(strErr) 
    1107     return (c_text) 
    1108  
    1109 def translate_from_file (functions, constants=0): 
    1110     print ("Size of functions is: " + len(functions)) 
     1084def translate (functions, constants=0): 
    11111085    sniplets = [] 
    11121086    fname = functions[1] 
     
    11621136        print ("Error:\n" + str(excp.args)) 
    11631137    print("...Done") 
    1164 #            "program": "${workspaceRoot}/ -m sascomp.generate", 
Note: See TracChangeset for help on using the changeset viewer.