Changeset 71c5f4d in sasmodels


Ignore:
Timestamp:
Dec 11, 2017 5:35:15 PM (5 years ago)
Author:
Omer Eisenberg <omereis@…>
Children:
e2719fc
Parents:
7f79cba
Message:

included constants in C. Fixed bug in C for loop

Files:
7 added
2 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/autoc.py

    r7a40b08 r71c5f4d  
    1010import numpy as np 
    1111 
    12 from . import codegen 
     12from . import py2c 
    1313from . import special 
    1414 
     
    129129 
    130130    # translate source 
    131     functions = codegen.translate( 
     131    functions = py2c.translate( 
    132132        [code[name] for name in ordered_dag(depends) if name in code], 
    133133        constants) 
     
    136136    print("\n".join(snippets)) 
    137137    #return 
    138     raise RuntimeError("not yet converted...") 
     138#    raise RuntimeError("not yet converted...") 
    139139 
    140140    # update model info 
  • sasmodels/py2c.py

    r3f9db6e r71c5f4d  
    4848import ast 
    4949import sys 
     50import os 
    5051from ast import NodeVisitor 
    5152 
     
    8889 
    8990#def to_source(node, indent_with=' ' * 4, add_line_information=False): 
    90 def to_source(node, func_name): 
     91def to_source(node, func_name, global_vectors): 
    9192    """This function can convert a node tree back into python sourcecode. 
    9293    This is useful for debugging purposes, especially if you're dealing with 
     
    105106    of the nodes are added to the output.  This can be used to spot wrong line 
    106107    number information of statement nodes. 
     108 
     109    global_vectors arguments is a list of strings, holding the names of global 
     110    variables already declared as arrays 
    107111    """ 
     112#    print(str(node)) 
     113#    return 
    108114    generator = SourceGenerator(' ' * 4, False) 
    109     generator.required_functions = func_name 
     115    generator.global_vectors = global_vectors 
     116#    generator.required_functions = func_name 
    110117    generator.visit(node) 
    111118 
    112119#    return ''.join(generator.result) 
    113120    return ''.join(generator.c_proc) 
     121 
     122EXTREN_DOUBLE = "extern double " 
     123 
     124def get_c_array_value(val): 
     125    val_str = str(val).strip() 
     126    val_str = val_str.replace('[','') 
     127    val_str = val_str.replace(']','') 
     128    val_str = val_str.strip() 
     129    double_blank = "  " in val_str 
     130    while (double_blank): 
     131        val_str = val_str.replace('  ',' ') 
     132        double_blank = "  " in val_str 
     133    val_str = val_str.replace(' ',',') 
     134    return (val_str) 
     135 
     136def write_include_const (snippets, constants): 
     137    c_globals=[] 
     138    snippets.append("#include <math.h>\n") 
     139    const_names = constants.keys() 
     140    for name in const_names: 
     141        c_globals.append(name) 
     142        val = constants[name] 
     143        var_dcl = "double "# + " = " 
     144        if (hasattr(val, "__len__")): # the value is an array 
     145            val_str = get_c_array_value(val) 
     146            var_dcl += str(name) + "[] = {" + str(val_str) + "}" 
     147        else: 
     148            var_dcl += str(name) + " = " + str(val) 
     149        var_dcl += ";\n" 
     150        snippets.append(var_dcl) 
     151             
     152    return (c_globals) 
     153#    for c_var in constants: 
     154#        sTmp = EXTREN_DOUBLE + str(c_var) + ";" 
     155#        c_externs.append (EXTREN_DOUBLE + c_var + ";") 
     156#    snippets.append (c_externs) 
    114157 
    115158def isevaluable(s): 
     
    154197        self.is_sequence = False 
    155198        self.visited_args = False 
     199        self.global_vectors = [] 
    156200 
    157201    def write_python(self, x): 
     
    342386        return (s) 
    343387 
     388    def write_C_Pointers (self, start_var): 
     389        if (len (self.C_DclPointers) > 0): 
     390            vars = "" 
     391            for c_ptr in self.C_DclPointers: 
     392                if (len(vars) > 0): 
     393                    vars += ", " 
     394                if (c_ptr not in self.arguments): 
     395                    vars += "*" + c_ptr 
     396                if (c_ptr in self.C_Vars): 
     397                    if (c_ptr in self.C_Vars): 
     398                        self.C_Vars.remove (c_ptr) 
     399            if (len(vars) > 0): 
     400                c_dcl = "    double " + vars + ";" 
     401                self.c_proc.insert (start_var, c_dcl + "\n") 
     402                start_var += 1 
     403        return start_var 
     404 
    344405    def insert_C_Vars (self, start_var): 
    345406        fLine = False 
     
    372433                self.c_proc.insert (start_var, c_dcl + "\n") 
    373434                start_var += 1 
    374         if (len (self.C_DclPointers) > 0): 
    375             vars = '' 
    376             for n in range(len(self.C_DclPointers)): 
    377                 if (len(vars) > 0): 
    378                     vars += ', ' 
    379                 vars += "*" + self.C_DclPointers[n] 
    380             if (len(vars) > 0): 
    381                 c_dcl = "    double " + vars + ";" 
    382                 self.c_proc.insert (start_var, c_dcl + "\n") 
    383                 start_var += 1 
    384435        self.C_Vars.clear() 
    385436        self.C_IntVars.clear() 
    386437        self.C_Vectors.clear() 
    387438        self.C_Pointers.clear() 
     439        self.C_DclPointers.clear() 
    388440        if (fLine == True): 
    389441            self.c_proc.insert (start_var, "\n") 
     
    398450            self.c_proc.insert (start_var + 1, "\n") 
    399451 
    400     def writeInclude (self): 
    401         if (self.MathIncludeed == False): 
    402             self.add_c_line("#include <math.h>\n") 
    403             self.add_c_line("static double pi = 3.14159265359;\n") 
    404             self.MathIncludeed = True 
     452#    def writeInclude (self): 
     453#        if (self.MathIncludeed == False): 
     454#            self.add_c_line("#include <math.h>\n") 
     455#            self.add_c_line("static double pi = 3.14159265359;\n") 
     456#            self.MathIncludeed = True 
    405457 
    406458    def ListToString (self, strings): 
     
    431483                args_str += ", " 
    432484        self.strMethodSignature = 'double ' + self.name + ' (' + args_str + ")" 
    433         if (self.signature_line > 0): 
     485        if (self.signature_line >= 0): 
    434486            self.c_proc.insert (self.signature_line, self.strMethodSignature) 
    435487 
     
    440492        self.arguments = [] 
    441493        self.name = node.name 
    442         if self.name not in self.required_functions[0]: 
    443             return 
     494#        if self.name not in self.required_functions[0]: 
     495#            return 
    444496        print("Parsing '" + self.name + "'") 
    445497        args_str = "" 
     
    447499        self.visit(node.args) 
    448500# for C 
    449         self.writeInclude() 
     501#        self.writeInclude() 
    450502        self.getMethodSignature () 
    451503# for C 
     
    547599                    self.body_or_else(node) 
    548600                    self.write_c ("}") 
     601                    self.add_current_line () # just for breaking point. to be deleted.  
    549602                    fForDone = True 
    550603        if (fForDone == False): 
     
    731784            name = node.id 
    732785#       add variable to C_Vars if it ins't there yet, not an argument and not a number 
    733         if ((name not in self.C_Functions) and (name not in self.C_Vars) and (name not in self.C_IntVars) and (name not in self.arguments) and (name.isnumeric () == False)): 
    734             if (self.InSubscript): 
    735                 self.C_IntVars.append (node.id) 
    736             else: 
    737                 self.C_Vars.append (node.id) 
     786        if (name not in self.global_vectors): 
     787            if ((name not in self.C_Functions) and (name not in self.C_Vars) and (name not in self.C_IntVars) and (name not in self.arguments) and (name.isnumeric () == False)): 
     788                if (self.InSubscript): 
     789                    self.C_IntVars.append (node.id) 
     790                else: 
     791                    self.C_Vars.append (node.id) 
    738792 
    739793    def visit_Str(self, node): 
     
    10231077        print(tree_source) 
    10241078 
    1025 def translate (functions, constants=0): 
     1079 
     1080def write_consts (constants): 
     1081    f = False 
     1082    if (constants is not None): 
     1083        f = True 
     1084 
     1085def translate(functions, constants): 
     1086    # type: (List[Tuple[str, str, int]], Dict[str, Any]) -> str 
     1087#    print ("Size of functions is: " + str(len(functions))) 
     1088    if (os.path.exists ("xlate.c")): 
     1089        os.remove("xlate.c") 
     1090    snippets = [] 
     1091    global_vectors = write_include_const (snippets, constants) 
     1092    const_keys = constants.keys () 
     1093    for source, filename, offset in functions: 
     1094        tree = ast.parse(source) 
     1095        snippet = to_source(tree, filename, global_vectors)#, offset) 
     1096        snippets.append(snippet) 
     1097    if (constants is not None): 
     1098        write_consts (constants) 
     1099    c_text = "\n".join(snippets) 
     1100    try: 
     1101        translated = open ("xlate.c", "a+") 
     1102        translated.write (c_text) 
     1103        translated.close() 
     1104    except Exception as excp: 
     1105        strErr = "Error:\n" + str(excp.args) 
     1106        print(strErr) 
     1107    return (c_text) 
     1108 
     1109def translate_from_file (functions, constants=0): 
     1110    print ("Size of functions is: " + len(functions)) 
    10261111    sniplets = [] 
    10271112    fname = functions[1] 
Note: See TracChangeset for help on using the changeset viewer.