Changes in sasmodels/py2c.py [ddfdb16:7b1dcf9] in sasmodels


Ignore:
File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/py2c.py

    rddfdb16 r7b1dcf9  
    8181BOOLOP_SYMBOLS = {} 
    8282BOOLOP_SYMBOLS[ast.And] = '&&' 
    83 BOOLOP_SYMBOLS[ast.Or]  = '||' 
     83BOOLOP_SYMBOLS[ast.Or] = '||' 
    8484 
    8585CMPOP_SYMBOLS = {} 
    86 CMPOP_SYMBOLS[ast.Eq]    = '==' 
     86CMPOP_SYMBOLS[ast.Eq] = '==' 
    8787CMPOP_SYMBOLS[ast.NotEq] = '!=' 
    8888CMPOP_SYMBOLS[ast.Lt] = '<' 
     
    145145        self.new_lines = 0 
    146146        self.c_proc = [] 
    147 # for C 
     147        # for C 
    148148        self.signature_line = 0 
    149149        self.arguments = [] 
     
    181181    def add_c_line(self, x): 
    182182        string = '' 
    183         for i in range(self.indentation): 
     183        for _ in range(self.indentation): 
    184184            string += ("    ") 
    185185        string += str(x) 
     
    188188 
    189189    def add_current_line(self): 
    190         if(len(self.current_statement) > 0): 
     190        if self.current_statement: 
    191191            self.add_c_line(self.current_statement) 
    192192            self.current_statement = '' 
    193193 
    194194    def AddUniqueVar(self, new_var): 
    195         if((new_var not in self.C_Vars)): 
     195        if new_var not in self.C_Vars: 
    196196            self.C_Vars.append(str(new_var)) 
    197197 
     
    210210            self.write_c('# line: %s' % node.lineno) 
    211211            self.new_lines = 1 
    212         if(len(self.current_statement)): 
     212        if self.current_statement: 
    213213            self.Statements.append(self.current_statement) 
    214214            self.current_statement = '' 
    215215 
    216216    def body(self, statements): 
    217         if(len(self.current_statement)): 
     217        if self.current_statement: 
    218218            self.add_current_line() 
    219219        self.new_line = True 
    220220        self.indentation += 1 
    221221        for stmt in statements: 
    222             target_name = '' 
    223             if(hasattr(stmt, 'targets')): 
    224                 if(hasattr(stmt.targets[0], 'id')): 
    225                     target_name = stmt.targets[0].id # target name needed for debug only 
     222            #if hasattr(stmt, 'targets') and hasattr(stmt.targets[0], 'id'): 
     223            #    target_name = stmt.targets[0].id # target name needed for debug only 
    226224            self.visit(stmt) 
    227225        self.add_current_line() # just for breaking point. to be deleted. 
     
    242240            else: 
    243241                want_comma.append(True) 
    244 # for C 
     242 
     243        # for C 
    245244        for arg in node.args: 
    246             self.arguments.append(arg.arg) 
     245            # CRUFT: 2.7 uses arg.id, 3.x uses arg.arg 
     246            try: 
     247                arg_name = arg.arg 
     248            except AttributeError: 
     249                arg_name = arg.id 
     250            self.arguments.append(arg_name) 
    247251 
    248252        padding = [None] *(len(node.args) - len(node.defaults)) 
    249253        for arg, default in zip(node.args, padding + node.defaults): 
    250254            if default is not None: 
    251                 self.warnings.append("Default Parameter unknown to C") 
    252                 w_str = "Default Parameters are unknown to C: '" + arg.arg + \ 
    253                         " = " + str(default.n) + "'" 
     255                # CRUFT: 2.7 uses arg.id, 3.x uses arg.arg 
     256                try: 
     257                    arg_name = arg.arg 
     258                except AttributeError: 
     259                    arg_name = arg.id 
     260                w_str = ("Default Parameters are unknown to C: '%s = %s" 
     261                         % arg_name, str(default.n)) 
    254262                self.warnings.append(w_str) 
    255263 
     
    271279 
    272280    def define_C_Vars(self, target): 
    273         if(hasattr(target, 'id')): 
    274 # a variable is considered an array if it apears in the agrument list 
    275 # and being assigned to. For example, the variable p in the following 
    276 # sniplet is a pointer, while q is not 
    277 # def somefunc(p, q): 
    278 #  p = q + 1 
    279 #  return 
    280 # 
    281             if(target.id not in self.C_Vars): 
    282                 if(target.id in self.arguments): 
     281        if hasattr(target, 'id'): 
     282        # a variable is considered an array if it apears in the agrument list 
     283        # and being assigned to. For example, the variable p in the following 
     284        # sniplet is a pointer, while q is not 
     285        # def somefunc(p, q): 
     286        #  p = q + 1 
     287        #  return 
     288        # 
     289            if target.id not in self.C_Vars: 
     290                if target.id in self.arguments: 
    283291                    idx = self.arguments.index(target.id) 
    284292                    new_target = self.arguments[idx] + "[0]" 
    285                     if(new_target not in self.C_Pointers): 
     293                    if new_target not in self.C_Pointers: 
    286294                        target.id = new_target 
    287295                        self.C_Pointers.append(self.arguments[idx]) 
     
    291299    def add_semi_colon(self): 
    292300        semi_pos = self.current_statement.find(';') 
    293         if(semi_pos > 0.0): 
    294             self.current_statement = self.current_statement.replace(';','') 
     301        if semi_pos > 0.0: 
     302            self.current_statement = self.current_statement.replace(';', '') 
    295303        self.write_c(';') 
    296304 
     
    302310            self.define_C_Vars(target) 
    303311            self.visit(target) 
    304         if(len(self.Tuples) > 0): 
     312        if self.Tuples: 
    305313            tplTargets = list(self.Tuples) 
    306             self.Tuples.clear() 
     314            del self.Tuples[:] 
    307315        self.write_c(' = ') 
    308316        self.is_sequence = False 
     
    317325            self.add_semi_colon() 
    318326            self.add_current_line() 
    319         if((self.is_sequence) and (not self.visited_args)): 
     327        if self.is_sequence and not self.visited_args: 
    320328            for target in node.targets: 
    321                 if(hasattr(target, 'id')): 
    322                     if((target.id in self.C_Vars) and(target.id not in self.C_DclPointers)): 
    323                         if(target.id not in self.C_DclPointers): 
     329                if hasattr(target, 'id'): 
     330                    if target.id in self.C_Vars and target.id not in self.C_DclPointers: 
     331                        if target.id not in self.C_DclPointers: 
    324332                            self.C_DclPointers.append(target.id) 
    325                             if(target.id in self.C_Vars): 
     333                            if target.id in self.C_Vars: 
    326334                                self.C_Vars.remove(target.id) 
    327335        self.current_statement = '' 
    328336 
    329337    def visit_AugAssign(self, node): 
    330         if(node.target.id not in self.C_Vars): 
    331             if(node.target.id not in self.arguments): 
     338        if node.target.id not in self.C_Vars: 
     339            if node.target.id not in self.arguments: 
    332340                self.C_Vars.append(node.target.id) 
    333341        self.visit(node.target) 
     
    355363        self.generic_visit(node) 
    356364 
    357     def listToDeclare(self, Vars): 
    358         s = '' 
    359         if(len(Vars) > 0): 
    360             s = ",".join(Vars) 
    361         return(s) 
     365    def listToDeclare(self, vars): 
     366        return ", ".join(vars) 
    362367 
    363368    def write_C_Pointers(self, start_var): 
    364         if(len(self.C_DclPointers) > 0): 
    365             vars = "" 
     369        if self.C_DclPointers: 
     370            var_list = [] 
    366371            for c_ptr in self.C_DclPointers: 
    367372                if(len(vars) > 0): 
    368373                    vars += ", " 
    369                 if(c_ptr not in self.arguments): 
    370                     vars += "*" + c_ptr 
    371                 if(c_ptr in self.C_Vars): 
    372                     if(c_ptr in self.C_Vars): 
    373                         self.C_Vars.remove(c_ptr) 
    374             if(len(vars) > 0): 
    375                 c_dcl = "    double " + vars + ";" 
    376                 self.c_proc.insert(start_var, c_dcl + "\n") 
     374                if c_ptr not in self.arguments: 
     375                    var_list.append("*" + c_ptr) 
     376                if c_ptr in self.C_Vars: 
     377                    self.C_Vars.remove(c_ptr) 
     378            if var_list: 
     379                c_dcl = "    double " + ", ".join(var_list) + ";\n" 
     380                self.c_proc.insert(start_var, c_dcl) 
    377381                start_var += 1 
    378382        return start_var 
     
    381385        fLine = False 
    382386        start_var = self.write_C_Pointers(start_var) 
    383         if(len(self.C_IntVars) > 0): 
     387        if self.C_IntVars: 
    384388            for var in self.C_IntVars: 
    385                 if(var in self.C_Vars): 
     389                if var in self.C_Vars: 
    386390                    self.C_Vars.remove(var) 
    387391            s = self.listToDeclare(self.C_IntVars) 
     
    390394            start_var += 1 
    391395 
    392         if(len(self.C_Vars) > 0): 
     396        if self.C_Vars: 
    393397            s = self.listToDeclare(self.C_Vars) 
    394398            self.c_proc.insert(start_var, "    double " + s + ";\n") 
    395399            fLine = True 
    396400            start_var += 1 
    397         if(len(self.C_Vectors) > 0): 
     401 
     402        if self.C_Vectors: 
    398403            s = self.listToDeclare(self.C_Vectors) 
    399404            for n in range(len(self.C_Vectors)): 
     
    402407                self.c_proc.insert(start_var, c_dcl + "\n") 
    403408                start_var += 1 
    404         self.C_Vars.clear() 
    405         self.C_IntVars.clear() 
    406         self.C_Vectors.clear() 
    407         self.C_Pointers.clear() 
     409 
     410        del self.C_Vars[:] 
     411        del self.C_IntVars[:] 
     412        del self.C_Vectors[:] 
     413        del self.C_Pointers[:] 
    408414        self.C_DclPointers 
    409         if(fLine == True): 
     415        if fLine: 
    410416            self.c_proc.insert(start_var, "\n") 
    411         return 
    412         s = '' 
    413         for n in range(len(self.C_Vars)): 
    414             s += str(self.C_Vars[n]) 
    415             if n < len(self.C_Vars) - 1: 
    416                 s += ", " 
    417         if(len(s) > 0): 
    418             self.c_proc.insert(start_var, "    double " + s + ";\n") 
    419             self.c_proc.insert(start_var + 1, "\n") 
    420  
    421     def writeInclude(self): 
    422         if(self.MathIncludeed == False): 
    423             self.add_c_line("#include <math.h>\n") 
    424             self.add_c_line("static double pi = 3.14159265359;\n") 
    425             self.MathIncludeed = True 
    426  
    427     def ListToString(self, strings): 
    428         s = '' 
    429         for n in range(len(strings)): 
    430             s += strings[n] 
    431             if(n < (len(strings) - 1)): 
    432                 s += ", " 
    433         return(s) 
    434  
    435     def getMethodSignature(self): 
    436         args_str = '' 
    437         for n in range(len(self.arguments)): 
    438             args_str += "double " + self.arguments[n] 
    439             if(n < (len(self.arguments) - 1)): 
    440                 args_str += ", " 
    441         return(args_str) 
    442417 
    443418    def InsertSignature(self): 
    444         args_str = '' 
    445         for n in range(len(self.arguments)): 
    446             args_str += "double " + self.arguments[n] 
    447             if(self.arguments[n] in self.C_Pointers): 
    448                 args_str += "[]" 
    449             if(n < (len(self.arguments) - 1)): 
    450                 args_str += ", " 
     419        arg_decls = [] 
     420        for arg in self.arguments: 
     421            decl = "double " + arg 
     422            if arg in self.C_Pointers: 
     423                decl += "[]" 
     424            arg_decls.append(decl) 
     425        args_str = ", ".join(arg_decls) 
    451426        self.strMethodSignature = 'double ' + self.name + '(' + args_str + ")" 
    452         if(self.signature_line >= 0): 
     427        if self.signature_line >= 0: 
    453428            self.c_proc.insert(self.signature_line, self.strMethodSignature) 
    454429 
     
    459434        self.arguments = [] 
    460435        self.name = node.name 
    461         print("Parsing '" + self.name + "'") 
    462         args_str = "" 
     436        #if self.name not in self.required_functions[0]: 
     437        #   return 
     438        #print("Parsing '" + self.name + "'") 
    463439 
    464440        self.visit(node.args) 
    465         self.getMethodSignature() 
     441        # for C 
    466442        self.signature_line = len(self.c_proc) 
     443        #self.add_c_line(self.strMethodSignature) 
    467444        self.add_c_line("\n{") 
    468445        start_vars = len(self.c_proc) + 1 
     
    489466            paren_or_comma() 
    490467            self.visit(base) 
    491         # XXX: the if here is used to keep this module compatible 
    492         #      with python 2.6. 
     468        # CRUFT: python 2.6 does not have "keywords" attribute 
    493469        if hasattr(node, 'keywords'): 
    494470            for keyword in node.keywords: 
     
    517493            if len(else_) == 0: 
    518494                break 
    519 #            elif hasattr(else_, 'orelse'): 
     495            #elif hasattr(else_, 'orelse'): 
    520496            elif len(else_) == 1 and isinstance(else_[0], ast.If): 
    521497                node = else_[0] 
    522 #                self.newline() 
     498                #self.newline() 
    523499                self.write_c('else if ') 
    524500                self.visit(node.test) 
     
    527503                self.add_current_line() 
    528504                self.add_c_line('}') 
    529 #                break 
     505                #break 
    530506            else: 
    531507                self.newline() 
     
    537513    def getNodeLineNo(self, node): 
    538514        line_number = -1 
    539         if(hasattr(node,'value')): 
     515        if hasattr(node, 'value'): 
    540516            line_number = node.value.lineno 
    541517        elif hasattr(node, 'iter'): 
    542518            if hasattr(node.iter, 'lineno'): 
    543519                line_number = node.iter.lineno 
    544         return(line_number) 
     520        return line_number 
    545521 
    546522    def GetNodeAsString(self, node): 
    547523        res = '' 
    548         if(hasattr(node, 'n')): 
     524        if hasattr(node, 'n'): 
    549525            res = str(node.n) 
    550         elif(hasattr(node, 'id')): 
     526        elif hasattr(node, 'id'): 
    551527            res = node.id 
    552         return(res) 
     528        return res 
    553529 
    554530    def GetForRange(self, node): 
     
    564540            self.current_statement = '' 
    565541        self.current_statement = temp_statement 
    566         if(len(for_args) == 1): 
     542        if len(for_args) == 1: 
    567543            stop = for_args[0] 
    568         elif(len(for_args) == 2): 
     544        elif len(for_args) == 2: 
    569545            start = for_args[0] 
    570546            stop = for_args[1] 
    571         elif(len(for_args) == 3): 
     547        elif len(for_args) == 3: 
    572548            start = for_args[0] 
    573549            stop = for_args[1] 
     
    575551        else: 
    576552            raise("Ilegal for loop parameters") 
    577         return(start, stop, step) 
     553        return start, stop, step 
    578554 
    579555    def visit_For(self, node): 
    580 # node: for iterator is stored in node.target. 
    581 # Iterator name is in node.target.id. 
     556        # node: for iterator is stored in node.target. 
     557        # Iterator name is in node.target.id. 
    582558        self.add_current_line() 
    583559        fForDone = False 
    584560        self.current_statement = '' 
    585         if(hasattr(node.iter, 'func')): 
    586             if(hasattr(node.iter.func, 'id')): 
    587                 if(node.iter.func.id == 'range'): 
     561        if hasattr(node.iter, 'func'): 
     562            if hasattr(node.iter.func, 'id'): 
     563                if node.iter.func.id == 'range': 
    588564                    self.visit(node.target) 
    589565                    iterator = self.current_statement 
    590566                    self.current_statement = '' 
    591                     if(iterator not in self.C_IntVars): 
     567                    if iterator not in self.C_IntVars: 
    592568                        self.C_IntVars.append(iterator) 
    593569                    start, stop, step = self.GetForRange(node) 
    594                     self.write_c("for(" + iterator + "=" + str(start) + \ 
    595                                   " ; " + iterator + " < " + str(stop) + \ 
    596                                   " ; " + iterator + " += " + str(step) + ") {") 
     570                    self.write_c("for(" + iterator + "=" + str(start) + 
     571                                 " ; " + iterator + " < " + str(stop) + 
     572                                 " ; " + iterator + " += " + str(step) + ") {") 
    597573                    self.body_or_else(node) 
    598574                    self.write_c("}") 
    599575                    fForDone = True 
    600         if(fForDone == False): 
     576        if not fForDone: 
    601577            line_number = self.getNodeLineNo(node) 
    602578            self.current_statement = '' 
     
    632608 
    633609    def visit_Print(self, node): 
    634 # XXX: python 2.6 only 
     610        # CRUFT: python 2.6 only 
    635611        self.newline(node) 
    636612        self.write_c('print ') 
     
    700676 
    701677    def visit_Raise(self, node): 
    702         # XXX: Python 2.6 / 3.0 compatibility 
     678        # CRUFT: Python 2.6 / 3.0 compatibility 
    703679        self.newline(node) 
    704680        self.write_python('raise') 
     
    734710            else: 
    735711                want_comma.append(True) 
    736         if(hasattr(node.func, 'id')): 
    737             if(node.func.id not in self.C_Functions): 
     712        if hasattr(node.func, 'id'): 
     713            if node.func.id not in self.C_Functions: 
    738714                self.C_Functions.append(node.func.id) 
    739             if(node.func.id == 'abs'): 
     715            if node.func.id == 'abs': 
    740716                self.write_c("fabs ") 
    741             elif(node.func.id == 'int'): 
     717            elif node.func.id == 'int': 
    742718                self.write_c('(int) ') 
    743             elif(node.func.id == "SINCOS"): 
     719            elif node.func.id == "SINCOS": 
    744720                self.WriteSincos(node) 
    745721                return 
     
    772748    def visit_Name(self, node): 
    773749        self.write_c(node.id) 
    774         if((node.id in self.C_Pointers) and(not self.SubRef)): 
     750        if node.id in self.C_Pointers and not self.SubRef: 
    775751            self.write_c("[0]") 
    776752        name = "" 
    777753        sub = node.id.find("[") 
    778         if(sub > 0): 
     754        if sub > 0: 
    779755            name = node.id[0:sub].strip() 
    780756        else: 
    781757            name = node.id 
    782 #      add variable to C_Vars if it ins't there yet, not an argument and not a number 
    783         if ((name not in self.C_Functions) and (name not in self.C_Vars) and \ 
    784             (name not in self.C_IntVars) and (name not in self.arguments) and \ 
    785             (name not in self.C_Constants) and (name.isnumeric() == False)): 
    786             if(self.InSubscript): 
     758        # add variable to C_Vars if it ins't there yet, not an argument and not a number 
     759        if (name not in self.C_Functions and name not in self.C_Vars and 
     760                name not in self.C_IntVars and name not in self.arguments and 
     761                name not in self.C_Constants and not name.isdigit()): 
     762            if self.InSubscript: 
    787763                self.C_IntVars.append(node.id) 
    788764            else: 
     
    810786            s = "" 
    811787            for idx, item in enumerate(node.elts): 
    812                 if((idx > 0) and(len(s) > 0)): 
     788                if idx > 0 and s: 
    813789                    s += ', ' 
    814                 if(hasattr(item, 'id')): 
     790                if hasattr(item, 'id'): 
    815791                    s += item.id 
    816                 elif(hasattr(item, 'n')): 
     792                elif hasattr(item, 'n'): 
    817793                    s += str(item.n) 
    818             if(len(s) > 0): 
     794            if s: 
    819795                self.C_Vectors.append(s) 
    820796                vec_name = "vec"  + str(len(self.C_Vectors)) 
    821797                self.write_c(vec_name) 
    822                 vec_name += "#" 
    823798        return visit 
    824799 
     
    840815        function_name = '' 
    841816        is_negative_exp = False 
    842         if(isevaluable(str(self.current_statement))): 
     817        if isevaluable(str(self.current_statement)): 
    843818            exponent = eval(string) 
    844819            is_negative_exp = exponent < 0 
    845820            abs_exponent = abs(exponent) 
    846             if(abs_exponent == 2): 
     821            if abs_exponent == 2: 
    847822                function_name = "square" 
    848             elif(abs_exponent == 3): 
     823            elif abs_exponent == 3: 
    849824                function_name = "cube" 
    850             elif(abs_exponent == 0.5): 
     825            elif abs_exponent == 0.5: 
    851826                function_name = "sqrt" 
    852             elif(abs_exponent == 1.0/3.0): 
     827            elif abs_exponent == 1.0/3.0: 
    853828                function_name = "cbrt" 
    854         if(function_name == ''): 
     829        if function_name == '': 
    855830            function_name = "pow" 
    856831        return function_name, is_negative_exp 
    857832 
    858833    def translate_power(self, node): 
    859 # get exponent by visiting the right hand argument. 
     834        # get exponent by visiting the right hand argument. 
    860835        function_name = "pow" 
    861836        temp_statement = self.current_statement 
    862 # 'visit' functions write the results to the 'current_statement' class memnber 
    863 # Here, a temporary variable, 'temp_statement', is used, that enables the 
    864 # use of the 'visit' function 
     837        # 'visit' functions write the results to the 'current_statement' class memnber 
     838        # Here, a temporary variable, 'temp_statement', is used, that enables the 
     839        # use of the 'visit' function 
    865840        self.current_statement = '' 
    866841        self.visit(node.right) 
     
    868843        function_name, is_negative_exp = self.get_special_power(self.current_statement) 
    869844        self.current_statement = temp_statement 
    870         if(is_negative_exp): 
     845        if is_negative_exp: 
    871846            self.write_c("1.0 /(") 
    872847        self.write_c(function_name + "(") 
    873848        self.visit(node.left) 
    874         if(function_name == "pow"): 
     849        if function_name == "pow": 
    875850            self.write_c(", ") 
    876851            self.visit(node.right) 
    877852        self.write_c(")") 
    878         if(is_negative_exp): 
     853        if is_negative_exp: 
    879854            self.write_c(")") 
    880855        self.write_c(" ") 
     
    889864    def visit_BinOp(self, node): 
    890865        self.write_c("(") 
    891         if('%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.Pow]): 
     866        if '%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.Pow]: 
    892867            self.translate_power(node) 
    893         elif('%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.FloorDiv]): 
     868        elif '%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.FloorDiv]: 
    894869            self.translate_integer_divide(node) 
    895870        else: 
     
    899874        self.write_c(")") 
    900875 
    901 #      for C 
     876    # for C 
    902877    def visit_BoolOp(self, node): 
    903878        self.write_c('(') 
     
    926901 
    927902    def visit_Subscript(self, node): 
    928         if (node.value.id not in self.C_Constants): 
    929             if(node.value.id not in self.C_Pointers): 
     903        if node.value.id not in self.C_Constants: 
     904            if node.value.id not in self.C_Pointers: 
    930905                self.C_Pointers.append(node.value.id) 
    931906        self.SubRef = True 
     
    976951                self.visit(comprehension) 
    977952            self.write_c(right) 
    978 #            self.write_python(right) 
     953            #self.write_python(right) 
    979954        return visit 
    980955 
     
    1005980 
    1006981    def visit_Repr(self, node): 
    1007         # XXX: python 2.6 only 
     982        # CRUFT: python 2.6 only 
    1008983        self.write_c('`') 
    1009984        self.visit(node.value) 
     
    1021996        self.visit(node.target) 
    1022997        self.write_C(' in ') 
    1023 #        self.write_python(' in ') 
     998        #self.write_python(' in ') 
    1024999        self.visit(node.iter) 
    10251000        if node.ifs: 
     
    10431018        print(tree_source) 
    10441019 
    1045 def add_constants(sniplets, c_constants): 
    1046     sniplets.append("#include <math.h>") 
    1047     sniplets.append("") 
    1048     vars = c_constants.keys() 
    1049     for c_var in vars: 
    1050         c_values = c_constants[c_var] 
    1051         if isinstance(c_values, (int, float)): 
    1052             parts = ["double ", c_var, " = ", "%.15g"%c_values, ";"] 
    1053         else: 
    1054             elements = ["%.15g"%v for v in c_values] 
    1055             parts = ["double ", c_var, "[]", " = ", "{\n   ", ", ".join(elements), "\n};"] 
    1056         sniplets.append("".join(parts)) 
    1057  
    10581020def translate(functions, constants=0): 
    1059     sniplets = [] 
    1060     add_constants (sniplets, constants) 
    1061     for source,fname,line_no in functions: 
    1062         line_directive = '#line %d "%s"' %(line_no,fname) 
    1063         line_directive = line_directive.replace('\\','\\\\') 
    1064 #        sniplets.append(line_directive) 
     1021    snippets = [] 
     1022    #snippets.append("#include <math.h>") 
     1023    #snippets.append("") 
     1024    for source, fname, line_no in functions: 
     1025        line_directive = '#line %d "%s"'%(line_no, fname.replace('\\', '\\\\')) 
     1026        snippets.append(line_directive) 
    10651027        tree = ast.parse(source) 
    1066         sniplet = to_source(tree, functions, constants) # in the future add filename, offset, constants 
    1067         sniplets.append(sniplet) 
    1068     c_code = "\n".join(sniplets) 
    1069     f_out = open ("xlate.c", "w+") 
    1070     f_out.write (c_code) 
    1071     f_out.close() 
    1072     return("\n".join(sniplets)) 
    1073  
    1074 def get_file_names(): 
    1075     fname_in = "" 
    1076     fname_out = "" 
    1077     if(len(sys.argv) > 1): 
    1078         fname_in = sys.argv[1] 
    1079         fname_base = os.path.splitext(fname_in) 
    1080         if(len(sys.argv) == 2): 
    1081             fname_out = str(fname_base[0]) + '.c' 
    1082         else: 
    1083             fname_out = sys.argv[2] 
    1084         if(len(fname_in) > 0): 
    1085             python_file = open(sys.argv[1], "r") 
    1086             if(len(fname_out) > 0): 
    1087                 file_out = open(fname_out, "w+") 
    1088     return len(sys.argv), fname_in, fname_out 
    1089  
    1090 if __name__ == "__main__": 
     1028        # in the future add filename, offset, constants 
     1029        c_code = to_source(tree, functions, constants) 
     1030        snippets.append(c_code) 
     1031    return snippets 
     1032 
     1033def main(): 
    10911034    import os 
    10921035    print("Parsing...using Python" + sys.version) 
    1093     try: 
    1094         fname_in = "" 
    1095         fname_out = "" 
    1096         if(len(sys.argv) == 1): 
    1097             print("Usage:\npython parse01.py <infile> [<outfile>](if omitted, output file is '<infile>.c'") 
    1098         else: 
    1099             fname_in = sys.argv[1] 
    1100             fname_base = os.path.splitext(fname_in) 
    1101             if(len(sys.argv) == 2): 
    1102                 fname_out = str(fname_base[0]) + '.c' 
    1103             else: 
    1104                 fname_out = sys.argv[2] 
    1105             if(len(fname_in) > 0): 
    1106                 python_file = open(sys.argv[1], "r") 
    1107                 if(len(fname_out) > 0): 
    1108                     file_out = open(fname_out, "w+") 
    1109                 functions = ["MultAsgn", "Iq41", "Iq2"] 
    1110                 tpls = [functions, fname_in, 0] 
    1111                 c_txt = translate(tpls) 
    1112                 file_out.write(c_txt) 
    1113                 file_out.close() 
    1114     except Exception as excp: 
    1115         print("Error:\n" + str(excp.args)) 
     1036    if len(sys.argv) == 1: 
     1037        print("""\ 
     1038Usage: python py2c.py <infile> [<outfile>] 
     1039 
     1040if outfile is omitted, output file is '<infile>.c' 
     1041""") 
     1042        return 
     1043 
     1044    fname_in = sys.argv[1] 
     1045    if len(sys.argv) == 2: 
     1046        fname_base = os.path.splitext(fname_in)[0] 
     1047        fname_out = str(fname_base) + '.c' 
     1048    else: 
     1049        fname_out = sys.argv[2] 
     1050 
     1051    with open(fname_in, "r") as python_file: 
     1052        code = python_file.read() 
     1053 
     1054    translation = translate([code, fname_in, 1])[0] 
     1055 
     1056    with open(fname_out, "w") as file_out: 
     1057        file_out.write(translation) 
    11161058    print("...Done") 
     1059 
     1060if __name__ == "__main__": 
     1061    main() 
Note: See TracChangeset for help on using the changeset viewer.