Changeset 4c87de0 in sasmodels


Ignore:
Timestamp:
Jan 3, 2018 7:13:10 PM (7 years ago)
Author:
Paul Kienzle <pkienzle@…>
Children:
6a37819
Parents:
15be191
Message:

improve docs; implement simple print; support True/False?; fix division; support ternery if; fix semicolon in expression

File:
1 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/py2c.py

    r15be191 r4c87de0  
    1 """ 
    2     py2c 
    3     ~~~~ 
    4  
    5     Convert simple numeric python code into C code. 
    6  
    7     The translate() function works on 
    8  
    9     Variables definition in C 
    10     ------------------------- 
    11     Defining variables within the translate function is a bit of a guess work, 
    12     using following rules: 
    13     *   By default, a variable is a 'double'. 
    14     *   Variable in a for loop is an int. 
    15     *   Variable that is references with brackets is an array of doubles. The 
    16         variable within the brackets is integer. For example, in the 
    17         reference 'var1[var2]', var1 is a double array, and var2 is an integer. 
    18     *   Assignment to an argument makes that argument an array, and the index 
    19         in that assignment is 0. 
    20         For example, the following python code:: 
    21             def func(arg1, arg2): 
    22                 arg2 = 17. 
    23         is translated to the following C code:: 
    24             double func(double arg1) 
    25             { 
    26                 arg2[0] = 17.0; 
    27             } 
    28         For example, the following python code is translated to the 
    29         following C code:: 
    30  
    31             def func(arg1, arg2):          double func(double arg1) { 
    32                 arg2 = 17.                      arg2[0] = 17.0; 
    33                                             } 
    34     *   All functions are defined as double, even if there is no 
    35         return statement. 
    36  
    37 Based on codegen.py: 
    38  
     1r""" 
     2py2c 
     3~~~~ 
     4 
     5Convert simple numeric python code into C code. 
     6 
     7This code is intended to translate direct algorithms for scientific code 
     8(mostly if statements and for loops operating on double precision values) 
     9into C code. Unlike projects like numba, cython, pypy and nuitka, the 
     10:func:`translate` function returns the corresponding C which can then be 
     11compiled with tinycc or sent to the GPU using CUDA or OpenCL. 
     12 
     13There is special handling certain constructs, such as *for i in range* and 
     14small integer powers. 
     15 
     16**TODO: make a nice list of supported constructs*** 
     17 
     18Imports are not supported, but they are at least ignored so that properly 
     19constructed code can be run via python or translated to C without change. 
     20 
     21Most other python constructs are **not** supported: 
     22* classes 
     23* builtin types (dict, set, list) 
     24* exceptions 
     25* with context 
     26* del 
     27* yield 
     28* async 
     29* list slicing 
     30* multiple return values 
     31* "is/is not", "in/not in" conditionals 
     32 
     33There is limited support for list and list comprehensions, so long as they 
     34can be represented by a fixed array whose size is known at compile time, and 
     35they are small enough to be stored on the stack. 
     36 
     37Variables definition in C 
     38------------------------- 
     39Defining variables within the translate function is a bit of a guess work, 
     40using following rules: 
     41*   By default, a variable is a 'double'. 
     42*   Variable in a for loop is an int. 
     43*   Variable that is references with brackets is an array of doubles. The 
     44    variable within the brackets is integer. For example, in the 
     45    reference 'var1[var2]', var1 is a double array, and var2 is an integer. 
     46*   Assignment to an argument makes that argument an array, and the index 
     47    in that assignment is 0. 
     48    For example, the following python code:: 
     49        def func(arg1, arg2): 
     50            arg2 = 17. 
     51    is translated to the following C code:: 
     52        double func(double arg1) 
     53        { 
     54            arg2[0] = 17.0; 
     55        } 
     56    For example, the following python code is translated to the 
     57    following C code:: 
     58 
     59        def func(arg1, arg2):          double func(double arg1) { 
     60            arg2 = 17.                      arg2[0] = 17.0; 
     61                                        } 
     62*   All functions are defined as double, even if there is no 
     63    return statement. 
     64 
     65Debugging 
     66--------- 
     67 
     68*print* is partially supported using a simple regular expression. This 
     69requires a stylized form. Be sure to use print as a function instead of 
     70the print statement. If you are including substition variables, use the 
     71% string substitution style. Include parentheses around the substitution 
     72tuple, even if there is only one item; do not include the final comma even 
     73if it is a single item (yes, it won't be a tuple, but it makes the regexp 
     74much simpler). Keep the item on a single line. Here are three forms that work:: 
     75 
     76    print("x") => printf("x\n"); 
     77    print("x %g"%(a)) => printf("x %g\n", a); 
     78    print("x %g %g %g"%(a, b, c)) => printf("x %g %g %g\n", a, b, c); 
     79 
     80You can generate *main* using the *if __name__ == "__main__":* construct. 
     81This does a simple substitution with "def main():" before translation and 
     82a substitution with "int main(int argc, double *argv[])" after translation. 
     83The result is that the content of the *if* block becomes the content of *main*. 
     84Along with the print statement, you can run and test a translation standalone 
     85using:: 
     86 
     87    python py2c.py source.py 
     88    cc source.c 
     89    ./a.out 
     90 
     91Known issues 
     92------------ 
     93The following constructs may cause problems: 
     94 
     95* implicit arrays: possible namespace collision for variable "vec#" 
     96* swap fails: "x,y = y,x" will set x==y 
     97* top-level statements: code outside a function body causes errors 
     98* line number skew: each statement should be tagged with its own #line 
     99  to avoid skew as comments are skipped and loop bodies are wrapped with 
     100  braces, etc. 
     101 
     102References 
     103---------- 
     104 
     105Based on a variant of codegen.py: 
     106 
     107    https://github.com/andreif/codegen 
    39108    :copyright: Copyright 2008 by Armin Ronacher. 
    40109    :license: BSD. 
    41110""" 
     111 
     112 
    42113""" 
    43114Update Notes 
     
    6413512/18/2017, OE: Added call to 'add_current_line()' at the beginning 
    65136                of visit_Return 
    66  
     1372018-01-03, PK: Update interface for use in sasmodels 
     1382018-01-03, PK: support "expr if cond else expr" syntax 
     1392018-01-03, PK: x//y => (int)((x)/(y)) and x/y => ((double)(x)/(double)(y)) 
     1402018-01-03, PK: True/False => true/false 
     1412018-01-03, PK: f(x) was introducing an extra semicolon 
     1422018-01-03, PK: simplistic print function, for debugging 
    67143""" 
    68144import ast 
     
    155231        self.c_vectors = [] 
    156232        self.c_constants = constants if constants is not None else {} 
     233        self.in_expr = False 
    157234        self.in_subref = False 
    158235        self.in_subscript = False 
     
    291368 
    292369    def add_semi_colon(self): 
    293         semi_pos = self.current_statement.find(';') 
    294         if semi_pos > 0.0: 
    295             self.current_statement = self.current_statement.replace(';', '') 
     370        #semi_pos = self.current_statement.find(';') 
     371        #if semi_pos >= 0: 
     372        #    self.current_statement = self.current_statement.replace(';', '') 
    296373        self.write_c(';') 
    297374 
    298375    def visit_Assign(self, node): 
    299376        self.add_current_line() 
     377        self.in_expr = True 
    300378        for idx, target in enumerate(node.targets): # multi assign, as in 'a = b = c = 7' 
    301379            if idx: 
     
    329407                                self.c_vars.remove(target.id) 
    330408        self.current_statement = '' 
     409        self.in_expr = False 
    331410 
    332411    def visit_AugAssign(self, node): 
     
    334413            if node.target.id not in self.arguments: 
    335414                self.c_vars.append(node.target.id) 
     415        self.in_expr = True 
    336416        self.visit(node.target) 
    337417        self.write_c(' ' + BINOP_SYMBOLS[type(node.op)] + '= ') 
    338418        self.visit(node.value) 
    339419        self.add_semi_colon() 
     420        self.in_expr = False 
    340421        self.add_current_line() 
    341422 
     
    357438 
    358439    def visit_Expr(self, node): 
     440        #self.in_expr = True 
    359441        self.newline(node) 
    360442        self.generic_visit(node) 
     443        #self.in_expr = False 
    361444 
    362445    def write_c_pointers(self, start_var): 
     
    474557 
    475558    def visit_If(self, node): 
     559 
    476560        self.write_c('if ') 
     561        self.in_expr = True 
    477562        self.visit(node.test) 
     563        self.in_expr = False 
    478564        self.write_c(' {') 
    479565        self.body(node.body) 
     
    488574                #self.newline() 
    489575                self.write_c('else if ') 
     576                self.in_expr = True 
    490577                self.visit(node.test) 
     578                self.in_expr = False 
    491579                self.write_c(' {') 
    492580                self.body(node.body) 
     
    541629                        self.c_int_vars.append(iterator) 
    542630                    start, stop, step = self.get_for_range(node) 
    543                     self.write_c("for(" + iterator + "=" + str(start) + 
     631                    self.write_c("for (" + iterator + "=" + str(start) + 
    544632                                 " ; " + iterator + " < " + str(stop) + 
    545633                                 " ; " + iterator + " += " + str(step) + ") {") 
     
    559647 
    560648    def visit_While(self, node): 
    561         self.unsupported(node) 
    562  
    563649        self.newline(node) 
    564650        self.write_c('while ') 
    565651        self.visit(node.test) 
    566         self.write_c(':') 
     652        self.write_c(' {') 
    567653        self.body_or_else(node) 
     654        self.write_c('}') 
     655        self.add_current_line() 
    568656 
    569657    def visit_With(self, node): 
     
    584672 
    585673    def visit_Print(self, node): 
    586         # TODO: print support would be nice, though hard to do 
    587674        self.unsupported(node) 
    588675 
     
    644731    def visit_Return(self, node): 
    645732        self.add_current_line() 
     733        self.in_expr = True 
    646734        if node.value is None: 
    647735            self.write_c('return') 
     
    651739        self.write_c(')') 
    652740        self.add_semi_colon() 
     741        self.in_expr = False 
    653742        self.add_c_line(self.current_statement) 
    654743        self.current_statement = '' 
     
    731820                self.write_c('**') 
    732821                self.visit(node.kwargs) 
    733         self.write_c(');') 
    734  
     822        self.write_c(')') 
     823        if not self.in_expr: 
     824            self.add_semi_colon() 
     825 
     826    TRANSLATE_CONSTANTS = { 
     827        'True': 'true', 
     828        'False': 'false', 
     829        'None': 'NULL',  # "None" will probably fail for other reasons 
     830        } 
    735831    def visit_Name(self, node): 
     832        translation = self.TRANSLATE_CONSTANTS.get(node.id, None) 
     833        if translation: 
     834            self.write_c(translation) 
     835            return 
    736836        self.write_c(node.id) 
    737837        if node.id in self.c_pointers and not self.in_subref: 
     
    753853 
    754854    def visit_Str(self, node): 
    755         self.write_c(repr(node.s)) 
     855        s = node.s 
     856        s = s.replace('\\','\\\\').replace('"','\\"').replace('\n','\\n') 
     857        self.write_c('"') 
     858        self.write_c(s) 
     859        self.write_c('"') 
    756860 
    757861    def visit_Bytes(self, node): 
    758         self.write_c(repr(node.s)) 
     862        s = node.s 
     863        s = s.replace('\\','\\\\').replace('"','\\"').replace('\n','\\n') 
     864        self.write_c('"') 
     865        self.write_c(s) 
     866        self.write_c('"') 
    759867 
    760868    def visit_Num(self, node): 
     
    842950        if is_negative_exp: 
    843951            self.write_c(")") 
    844         self.write_c(" ") 
     952        #self.write_c(" ") 
    845953 
    846954    def translate_integer_divide(self, node): 
    847         self.write_c("(int)(") 
     955        self.write_c("(int)((") 
    848956        self.visit(node.left) 
    849         self.write_c(") /(int)(") 
     957        self.write_c(")/(") 
    850958        self.visit(node.right) 
    851         self.write_c(")") 
     959        self.write_c("))") 
     960 
     961    def translate_float_divide(self, node): 
     962        self.write_c("((double)(") 
     963        self.visit(node.left) 
     964        self.write_c(")/(double)(") 
     965        self.visit(node.right) 
     966        self.write_c("))") 
    852967 
    853968    def visit_BinOp(self, node): 
     
    857972        elif '%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.FloorDiv]: 
    858973            self.translate_integer_divide(node) 
     974        elif '%s' % BINOP_SYMBOLS[type(node.op)] == BINOP_SYMBOLS[ast.Div]: 
     975            self.translate_float_divide(node) 
    859976        else: 
    860977            self.visit(node.left) 
     
    9661083 
    9671084    def visit_IfExp(self, node): 
     1085        self.write_c('((') 
     1086        self.visit(node.test) 
     1087        self.write_c(')?(') 
    9681088        self.visit(node.body) 
    969         self.write_c(' if ') 
    970         self.visit(node.test) 
    971         self.write_c(' else ') 
     1089        self.write_c('):(') 
    9721090        self.visit(node.orelse) 
     1091        self.write_c('))') 
    9731092 
    9741093    def visit_Starred(self, node): 
     
    11091228                         % ", ".join(str(node) for node in dag.keys())) 
    11101229 
     1230import re 
     1231PRINT_ARGS = re.compile(r'print[(]"(?P<template>[^"]*)" *% *[(](?P<args>[^\n]*)[)] *[)] *\n') 
     1232SUBST_ARGS = r'printf("\g<template>\\n", \g<args>)\n' 
     1233PRINT_STR = re.compile(r'print[(]"(?P<template>[^"]*)" *[)] *\n') 
     1234SUBST_STR = r'printf("\g<template>\n")' 
    11111235def translate(functions, constants=None): 
    1112     # type: (List[(str, str, int)], Dict[str, any]) -> List[str] 
     1236    # type: (Sequence[(str, str, int)], Dict[str, any]) -> List[str] 
    11131237    """ 
    1114     Convert a set of functions 
     1238    Convert a list of functions to a list of C code strings. 
     1239 
     1240    A function is given by the tuple (source, filename, line number). 
     1241 
     1242    Global constants are given in a dictionary of {name: value}.  The 
     1243    constants are used for name space resolution and type inferencing. 
     1244    Constants are not translated by this code. Instead, call 
     1245    :func:`define_constant` with name and value, and maybe block_size 
     1246    if arrays need to be padded to the next block boundary. 
     1247 
     1248    Function prototypes are not generated. Use :func:`ordered_dag` 
     1249    to list the functions in reverse order of dependency before calling 
     1250    translate. [Maybe a future revision will return the function prototypes 
     1251    so that a suitable "*.h" file can be generated. 
    11151252    """ 
    11161253    snippets = [] 
     
    11201257        line_directive = '#line %d "%s"\n'%(lineno, fname.replace('\\', '\\\\')) 
    11211258        snippets.append(line_directive) 
     1259        # Replace simple print function calls with printf statements 
     1260        source = PRINT_ARGS.sub(SUBST_ARGS, source) 
     1261        source = PRINT_STR.sub(SUBST_STR, source) 
    11221262        tree = ast.parse(source) 
    11231263        c_code = to_source(tree, constants=constants, fname=fname, lineno=lineno) 
     
    11491289            .replace(name+'.n', 'GAUSS_N') 
    11501290            .replace(name+'.z', 'GAUSS_Z') 
    1151             .replace(name+'.w', 'GAUSS_W')) 
    1152  
    1153     translation = translate([(code, fname_in, 1)]) 
     1291            .replace(name+'.w', 'GAUSS_W') 
     1292            .replace('if __name__ == "__main__"', "def main()") 
     1293    ) 
     1294 
     1295 
     1296    c_code = "".join(translate([(code, fname_in, 1)])) 
     1297    c_code = c_code.replace("double main()", "int main(int argc, char *argv[])") 
    11541298 
    11551299    with open(fname_out, "w") as file_out: 
    1156         file_out.write("".join(translation)) 
     1300        file_out.write(""" 
     1301#include <stdio.h> 
     1302#include <stdbool.h> 
     1303#include <math.h> 
     1304#define constant const 
     1305double square(double x) { return x*x; } 
     1306double cube(double x) { return x*x*x; } 
     1307double polyval(constant double *coef, double x, int N) 
     1308{ 
     1309    int i = 0; 
     1310    double ans = coef[0]; 
     1311 
     1312    while (i < N) { 
     1313        ans = ans * x + coef[i++]; 
     1314    } 
     1315 
     1316    return ans; 
     1317} 
     1318 
     1319""") 
     1320        file_out.write(c_code) 
    11571321    print("...Done") 
    11581322 
Note: See TracChangeset for help on using the changeset viewer.