Changeset c01ed3e in sasmodels


Ignore:
Timestamp:
Dec 22, 2017 4:48:12 PM (6 years ago)
Author:
Paul Kienzle <pkienzle@…>
Children:
15be191
Parents:
2694cb8
git-author:
Paul Kienzle <pkienzle@…> (12/22/17 16:46:46)
git-committer:
Paul Kienzle <pkienzle@…> (12/22/17 16:48:12)
Message:

code cleanup for py2c converter

Location:
sasmodels
Files:
1 deleted
4 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/autoc.py

    r67cc0ff rc01ed3e  
    44from __future__ import print_function 
    55 
    6 import ast 
    76import inspect 
    8 from functools import reduce 
    97 
    108import numpy as np 
     
    9694                # Claim all constants are declared on line 1 
    9795                snippets.append('#line 1 "%s"'%escaped_filename) 
    98                 snippets.append(define_constant(name, obj)) 
     96                snippets.append(py2c.define_constant(name, obj)) 
    9997            elif isinstance(obj, special.Gauss): 
    10098                for var, value in zip(("N", "Z", "W"), (obj.n, obj.z, obj.w)): 
     
    102100                    constants[var] = value 
    103101                    snippets.append('#line 1 "%s"'%escaped_filename) 
    104                     snippets.append(define_constant(var, value)) 
     102                    snippets.append(py2c.define_constant(var, value)) 
    105103                #libs.append('lib/gauss%d.c'%obj.n) 
    106104                source = (source.replace(name+'.n', 'GAUSS_N') 
     
    121119 
    122120    # translate source 
    123     ordered_code = [code[name] for name in ordered_dag(depends) if name in code] 
     121    ordered_code = [code[name] for name in py2c.ordered_dag(depends) if name in code] 
    124122    functions = py2c.translate(ordered_code, constants) 
    125123    snippets.extend(functions) 
     
    129127    info.c_code = "\n".join(snippets) 
    130128    info.Iq = info.Iqac = info.Iqabc = info.Iqxy = info.form_volume = None 
    131  
    132 def define_constant(name, value): 
    133     if isinstance(value, int): 
    134         parts = ["int ", name, " = ", "%d"%value, ";"] 
    135     elif isinstance(value, float): 
    136         parts = ["double ", name, " = ", "%.15g"%value, ";"] 
    137     else: 
    138         # extend constant arrays to a multiple of 4; not sure if this 
    139         # is necessary, but some OpenCL targets broke if the number 
    140         # of parameters in the parameter table was not a multiple of 4, 
    141         # so do it for all constant arrays to be safe. 
    142         if len(value)%4 != 0: 
    143             value = list(value) + [0.]*(4 - len(value)%4) 
    144         elements = ["%.15g"%v for v in value] 
    145         parts = ["double ", name, "[]", " = ", 
    146                  "{\n   ", ", ".join(elements), "\n};"] 
    147     return "".join(parts) 
    148  
    149  
    150 # Modified from the following: 
    151 # 
    152 #    http://code.activestate.com/recipes/578272-topological-sort/ 
    153 #    Copyright (C) 2012 Sam Denton 
    154 #    License: MIT 
    155 def ordered_dag(dag): 
    156     # type: (Dict[T, Set[T]]) -> Iterator[T] 
    157     dag = dag.copy() 
    158  
    159     # make leaves depend on the empty set 
    160     leaves = reduce(set.union, dag.values()) - set(dag.keys()) 
    161     dag.update({node: set() for node in leaves}) 
    162     while True: 
    163         leaves = set(node for node, links in dag.items() if not links) 
    164         if not leaves: 
    165             break 
    166         for node in leaves: 
    167             yield node 
    168         dag = {node: (links-leaves) 
    169                for node, links in dag.items() if node not in leaves} 
    170     if dag: 
    171         raise ValueError("Cyclic dependes exists amongst these items:\n%s" 
    172                             % ", ".join(str(node) for node in dag.keys())) 
  • sasmodels/modelinfo.py

    r67cc0ff rc01ed3e  
    810810    info.sesans = getattr(kernel_module, 'sesans', None) # type: ignore 
    811811    # Default single and opencl to True for C models.  Python models have callable Iq. 
    812     info.opencl = getattr(kernel_module, 'opencl', not callable(info.Iq)) 
    813     info.single = getattr(kernel_module, 'single', not callable(info.Iq)) 
    814812    info.random = getattr(kernel_module, 'random', None) 
    815813 
     
    827825        except Exception as exc: 
    828826            logger.warn(str(exc) + " while converting %s from C to python"%name) 
     827 
     828    # Needs to come after autoc.convert since the Iq symbol may have been 
     829    # converted from python to C 
     830    info.opencl = getattr(kernel_module, 'opencl', not callable(info.Iq)) 
     831    info.single = getattr(kernel_module, 'single', not callable(info.Iq)) 
    829832 
    830833    if callable(info.Iq) and parameters.has_2d: 
  • sasmodels/models/_cylpy.py

    r67cc0ff rc01ed3e  
    140140py2c = True 
    141141 
     142# TODO: "#define INVALID (expr)" is not supported 
    142143def invalid(v): 
    143144    return v.radius < 0 or v.length < 0 
     
    206207            phi_pd=10, phi_pd_n=5) 
    207208 
    208 qx, qy = 0.2 * np.cos(2.5), 0.2 * np.sin(2.5) 
     209qx, qy = 0.2 * cos(2.5), 0.2 * sin(2.5) 
    209210# After redefinition of angles, find new tests values.  Was 10 10 in old coords 
    210211tests = [ 
  • sasmodels/py2c.py

    r7b1dcf9 rc01ed3e  
    1  
    21""" 
    3     codegen 
    4     ~~~~~~~ 
    5  
    6     Extension to ast that allow ast -> python code generation. 
    7  
    8     :copyright: Copyright 2008 by Armin Ronacher. 
    9     :license: BSD. 
    10 """ 
    11 """ 
     2    py2c 
     3    ~~~~ 
     4 
     5    Convert simple numeric python code into C code. 
     6 
     7    The translate() function works on 
     8 
    129    Variables definition in C 
    1310    ------------------------- 
    14     Defining variables within the Translate function is a bit of a guess work, 
    15     using following rules. 
     11    Defining variables within the translate function is a bit of a guess work, 
     12    using following rules: 
    1613    *   By default, a variable is a 'double'. 
    1714    *   Variable in a for loop is an int. 
     
    1916        variable within the brackets is integer. For example, in the 
    2017        reference 'var1[var2]', var1 is a double array, and var2 is an integer. 
    21     *   Assignment to an argument makes that argument an array, and the index in 
    22         that assignment is 0. 
    23         For example, the following python code 
     18    *   Assignment to an argument makes that argument an array, and the index 
     19        in that assignment is 0. 
     20        For example, the following python code:: 
    2421            def func(arg1, arg2): 
    2522                arg2 = 17. 
    26         is translated to the following C code 
     23        is translated to the following C code:: 
    2724            double func(double arg1) 
    2825            { 
    2926                arg2[0] = 17.0; 
    3027            } 
    31         For example, the following python code is translated to the following C code 
     28        For example, the following python code is translated to the 
     29        following C code:: 
     30 
    3231            def func(arg1, arg2):          double func(double arg1) { 
    3332                arg2 = 17.                      arg2[0] = 17.0; 
    3433                                            } 
    35     *   All functions are defined as double, even if there is no return statement. 
    36  
    37  
     34    *   All functions are defined as double, even if there is no 
     35        return statement. 
     36 
     37Based on codegen.py: 
     38 
     39    :copyright: Copyright 2008 by Armin Ronacher. 
     40    :license: BSD. 
     41""" 
     42""" 
    3843Update Notes 
    3944============ 
    40 11/22 14:15, O.E.   Each 'visit_*' method is to build a C statement string. It 
     4511/22/2017, O.E.   Each 'visit_*' method is to build a C statement string. It 
    4146                    shold insert 4 blanks per indentation level. 
    4247                    The 'body' method will combine all the strings, by adding 
     
    525712/07/2017, OE: Translation of integer division, '\\' in python, implemented 
    5358                in translate_integer_divide, called from visit_BinOp 
    54 12/07/2017, OE: C variable definition handled in 'define_C_Vars' 
     5912/07/2017, OE: C variable definition handled in 'define_c_vars' 
    5560              : Python integer division, '//', translated to C in 
    5661                'translate_integer_divide' 
     
    102107 
    103108 
    104 def to_source(node, func_name, constants=None): 
    105     """This function can convert a node tree back into python sourcecode. 
    106     This is useful for debugging purposes, especially if you're dealing with 
    107     custom asts not generated by python itself. 
    108  
    109     It could be that the sourcecode is evaluable when the AST itself is not 
    110     compilable / evaluable.  The reason for this is that the AST contains some 
    111     more data than regular sourcecode does, which is dropped during 
    112     conversion. 
    113  
    114     Each level of indentation is replaced with `indent_with`.  Per default this 
    115     parameter is equal to four spaces as suggested by PEP 8, but it might be 
    116     adjusted to match the application's styleguide. 
    117  
    118     If `add_line_information` is set to `True` comments for the line numbers 
    119     of the nodes are added to the output.  This can be used to spot wrong line 
    120     number information of statement nodes. 
     109def to_source(tree, constants=None, fname=None, lineno=0): 
    121110    """ 
    122     generator = SourceGenerator(' ' * 4, False, constants) 
    123     generator.visit(node) 
    124  
    125     return ''.join(generator.c_proc) 
     111    This function can convert a syntax tree into C sourcecode. 
     112    """ 
     113    generator = SourceGenerator(constants=constants, fname=fname, lineno=lineno) 
     114    generator.visit(tree) 
     115    c_code = "\n".join(generator.c_proc) 
     116    return c_code 
    126117 
    127118def isevaluable(s): 
     
    129120        eval(s) 
    130121        return True 
    131     except: 
     122    except Exception: 
    132123        return False 
    133124 
     
    138129    """ 
    139130 
    140     def __init__(self, indent_with, add_line_information=False, constants=None): 
     131    def __init__(self, indent_with="    ", add_line_information=False, 
     132                 constants=None, fname=None, lineno=0): 
    141133        self.result = [] 
    142134        self.indent_with = indent_with 
     
    144136        self.indentation = 0 
    145137        self.new_lines = 0 
     138 
     139        # for C 
    146140        self.c_proc = [] 
    147         # for C 
    148141        self.signature_line = 0 
    149142        self.arguments = [] 
    150         self.name = "" 
     143        self.current_function = "" 
     144        self.fname = fname 
     145        self.lineno_offset = lineno 
    151146        self.warnings = [] 
    152         self.Statements = [] 
     147        self.statements = [] 
    153148        self.current_statement = "" 
    154         self.strMethodSignature = "" 
    155         self.C_Vars = [] 
    156         self.C_IntVars = [] 
    157         self.MathIncludeed = False 
    158         self.C_Pointers = [] 
    159         self.C_DclPointers = [] 
    160         self.C_Functions = [] 
    161         self.C_Vectors = [] 
    162         self.C_Constants = constants 
    163         self.SubRef = False 
    164         self.InSubscript = False 
    165         self.Tuples = [] 
     149        # TODO: use set rather than list for c_vars, ... 
     150        self.c_vars = [] 
     151        self.c_int_vars = [] 
     152        self.c_pointers = [] 
     153        self.c_dcl_pointers = [] 
     154        self.c_functions = [] 
     155        self.c_vectors = [] 
     156        self.c_constants = constants if constants is not None else {} 
     157        self.in_subref = False 
     158        self.in_subscript = False 
     159        self.tuples = [] 
    166160        self.required_functions = [] 
    167161        self.is_sequence = False 
     
    176170        self.result.append(x) 
    177171 
    178     def write_c(self, x): 
    179         self.current_statement += x 
    180  
    181     def add_c_line(self, x): 
    182         string = '' 
    183         for _ in range(self.indentation): 
    184             string += ("    ") 
    185         string += str(x) 
    186         self.c_proc.append(str(string + "\n")) 
    187         x = '' 
     172    def write_c(self, statement): 
     173        # TODO: build up as a list rather than adding to string 
     174        self.current_statement += statement 
     175 
     176    def add_c_line(self, line): 
     177        indentation = self.indent_with * self.indentation 
     178        self.c_proc.append("".join((indentation, line, "\n"))) 
    188179 
    189180    def add_current_line(self): 
     
    192183            self.current_statement = '' 
    193184 
    194     def AddUniqueVar(self, new_var): 
    195         if new_var not in self.C_Vars: 
    196             self.C_Vars.append(str(new_var)) 
    197  
    198     def WriteSincos(self, node): 
     185    def add_unique_var(self, new_var): 
     186        if new_var not in self.c_vars: 
     187            self.c_vars.append(str(new_var)) 
     188 
     189    def write_sincos(self, node): 
    199190        angle = str(node.args[0].id) 
    200191        self.write_c(node.args[1].id + " = sin(" + angle + ");") 
     
    203194        self.add_current_line() 
    204195        for arg in node.args: 
    205             self.AddUniqueVar(arg.id) 
     196            self.add_unique_var(arg.id) 
    206197 
    207198    def newline(self, node=None, extra=0): 
     
    211202            self.new_lines = 1 
    212203        if self.current_statement: 
    213             self.Statements.append(self.current_statement) 
     204            self.statements.append(self.current_statement) 
    214205            self.current_statement = '' 
    215206 
     
    278269            self.visit(node.msg) 
    279270 
    280     def define_C_Vars(self, target): 
     271    def define_c_vars(self, target): 
    281272        if hasattr(target, 'id'): 
    282273        # a variable is considered an array if it apears in the agrument list 
     
    287278        #  return 
    288279        # 
    289             if target.id not in self.C_Vars: 
     280            if target.id not in self.c_vars: 
    290281                if target.id in self.arguments: 
    291282                    idx = self.arguments.index(target.id) 
    292283                    new_target = self.arguments[idx] + "[0]" 
    293                     if new_target not in self.C_Pointers: 
     284                    if new_target not in self.c_pointers: 
    294285                        target.id = new_target 
    295                         self.C_Pointers.append(self.arguments[idx]) 
     286                        self.c_pointers.append(self.arguments[idx]) 
    296287                else: 
    297                     self.C_Vars.append(target.id) 
     288                    self.c_vars.append(target.id) 
    298289 
    299290    def add_semi_colon(self): 
     
    308299            if idx: 
    309300                self.write_c(' = ') 
    310             self.define_C_Vars(target) 
     301            self.define_c_vars(target) 
    311302            self.visit(target) 
    312         if self.Tuples: 
    313             tplTargets = list(self.Tuples) 
    314             del self.Tuples[:] 
     303        # Capture assigned tuple names, if any 
     304        targets = self.tuples[:] 
     305        del self.tuples[:] 
    315306        self.write_c(' = ') 
    316307        self.is_sequence = False 
     
    319310        self.add_semi_colon() 
    320311        self.add_current_line() 
    321         for n, item in enumerate(self.Tuples): 
    322             self.visit(tplTargets[n]) 
     312        # Assign tuples to tuples, if any 
     313        # TODO: doesn't handle swap:  a,b = b,a 
     314        for target, item in zip(targets, self.tuples): 
     315            self.visit(target) 
    323316            self.write_c(' = ') 
    324317            self.visit(item) 
     
    328321            for target in node.targets: 
    329322                if hasattr(target, 'id'): 
    330                     if target.id in self.C_Vars and target.id not in self.C_DclPointers: 
    331                         if target.id not in self.C_DclPointers: 
    332                             self.C_DclPointers.append(target.id) 
    333                             if target.id in self.C_Vars: 
    334                                 self.C_Vars.remove(target.id) 
     323                    if target.id in self.c_vars and target.id not in self.c_dcl_pointers: 
     324                        if target.id not in self.c_dcl_pointers: 
     325                            self.c_dcl_pointers.append(target.id) 
     326                            if target.id in self.c_vars: 
     327                                self.c_vars.remove(target.id) 
    335328        self.current_statement = '' 
    336329 
    337330    def visit_AugAssign(self, node): 
    338         if node.target.id not in self.C_Vars: 
     331        if node.target.id not in self.c_vars: 
    339332            if node.target.id not in self.arguments: 
    340                 self.C_Vars.append(node.target.id) 
     333                self.c_vars.append(node.target.id) 
    341334        self.visit(node.target) 
    342335        self.write_c(' ' + BINOP_SYMBOLS[type(node.op)] + '= ') 
     
    346339 
    347340    def visit_ImportFrom(self, node): 
     341        return  # import ignored 
    348342        self.newline(node) 
    349343        self.write_python('from %s%s import ' %('.' * node.level, node.module)) 
     
    354348 
    355349    def visit_Import(self, node): 
     350        return  # import ignored 
    356351        self.newline(node) 
    357352        for item in node.names: 
     
    363358        self.generic_visit(node) 
    364359 
    365     def listToDeclare(self, vars): 
    366         return ", ".join(vars) 
    367  
    368     def write_C_Pointers(self, start_var): 
    369         if self.C_DclPointers: 
     360    def write_c_pointers(self, start_var): 
     361        if self.c_dcl_pointers: 
    370362            var_list = [] 
    371             for c_ptr in self.C_DclPointers: 
    372                 if(len(vars) > 0): 
    373                     vars += ", " 
     363            for c_ptr in self.c_dcl_pointers: 
    374364                if c_ptr not in self.arguments: 
    375365                    var_list.append("*" + c_ptr) 
    376                 if c_ptr in self.C_Vars: 
    377                     self.C_Vars.remove(c_ptr) 
     366                if c_ptr in self.c_vars: 
     367                    self.c_vars.remove(c_ptr) 
    378368            if var_list: 
    379369                c_dcl = "    double " + ", ".join(var_list) + ";\n" 
     
    382372        return start_var 
    383373 
    384     def insert_C_Vars(self, start_var): 
    385         fLine = False 
    386         start_var = self.write_C_Pointers(start_var) 
    387         if self.C_IntVars: 
    388             for var in self.C_IntVars: 
    389                 if var in self.C_Vars: 
    390                     self.C_Vars.remove(var) 
    391             s = self.listToDeclare(self.C_IntVars) 
    392             self.c_proc.insert(start_var, "    int " + s + ";\n") 
    393             fLine = True 
     374    def insert_c_vars(self, start_var): 
     375        have_decls = False 
     376        start_var = self.write_c_pointers(start_var) 
     377        if self.c_int_vars: 
     378            for var in self.c_int_vars: 
     379                if var in self.c_vars: 
     380                    self.c_vars.remove(var) 
     381            decls = ", ".join(self.c_int_vars) 
     382            self.c_proc.insert(start_var, "    int " + decls + ";\n") 
     383            have_decls = True 
    394384            start_var += 1 
    395385 
    396         if self.C_Vars: 
    397             s = self.listToDeclare(self.C_Vars) 
    398             self.c_proc.insert(start_var, "    double " + s + ";\n") 
    399             fLine = True 
     386        if self.c_vars: 
     387            decls = ", ".join(self.c_vars) 
     388            self.c_proc.insert(start_var, "    double " + decls + ";\n") 
     389            have_decls = True 
    400390            start_var += 1 
    401391 
    402         if self.C_Vectors: 
    403             s = self.listToDeclare(self.C_Vectors) 
    404             for n in range(len(self.C_Vectors)): 
    405                 name = "vec" + str(n+1) 
    406                 c_dcl = "    double " + name + "[] = {" + self.C_Vectors[n] + "};" 
    407                 self.c_proc.insert(start_var, c_dcl + "\n") 
     392        if self.c_vectors: 
     393            for vec_number, vec_value  in enumerate(self.c_vectors): 
     394                name = "vec" + str(vec_number + 1) 
     395                decl = "    double " + name + "[] = {" + vec_value + "};" 
     396                self.c_proc.insert(start_var, decl + "\n") 
    408397                start_var += 1 
    409398 
    410         del self.C_Vars[:] 
    411         del self.C_IntVars[:] 
    412         del self.C_Vectors[:] 
    413         del self.C_Pointers[:] 
    414         self.C_DclPointers 
    415         if fLine: 
     399        del self.c_vars[:] 
     400        del self.c_int_vars[:] 
     401        del self.c_vectors[:] 
     402        del self.c_pointers[:] 
     403        del self.c_dcl_pointers[:] 
     404        if have_decls: 
    416405            self.c_proc.insert(start_var, "\n") 
    417406 
    418     def InsertSignature(self): 
     407    def insert_signature(self): 
    419408        arg_decls = [] 
    420409        for arg in self.arguments: 
    421410            decl = "double " + arg 
    422             if arg in self.C_Pointers: 
     411            if arg in self.c_pointers: 
    423412                decl += "[]" 
    424413            arg_decls.append(decl) 
    425414        args_str = ", ".join(arg_decls) 
    426         self.strMethodSignature = 'double ' + self.name + '(' + args_str + ")" 
     415        method_sig = 'double ' + self.current_function + '(' + args_str + ")" 
    427416        if self.signature_line >= 0: 
    428             self.c_proc.insert(self.signature_line, self.strMethodSignature) 
     417            self.c_proc.insert(self.signature_line, method_sig) 
    429418 
    430419    def visit_FunctionDef(self, node): 
     420        if self.current_function: 
     421            self.unsupported(node, "function within a function") 
     422        self.current_function = node.name 
     423 
    431424        self.newline(extra=1) 
    432425        self.decorators(node) 
    433426        self.newline(node) 
    434427        self.arguments = [] 
    435         self.name = node.name 
    436         #if self.name not in self.required_functions[0]: 
    437         #   return 
    438         #print("Parsing '" + self.name + "'") 
    439  
    440428        self.visit(node.args) 
    441429        # for C 
    442430        self.signature_line = len(self.c_proc) 
    443         #self.add_c_line(self.strMethodSignature) 
    444431        self.add_c_line("\n{") 
    445432        start_vars = len(self.c_proc) + 1 
    446433        self.body(node.body) 
    447434        self.add_c_line("}\n") 
    448         self.InsertSignature() 
    449         self.insert_C_Vars(start_vars) 
    450         self.C_Pointers = [] 
     435        self.insert_signature() 
     436        self.insert_c_vars(start_vars) 
     437        self.c_pointers = [] 
     438        self.current_function = "" 
    451439 
    452440    def visit_ClassDef(self, node): 
     
    511499                break 
    512500 
    513     def getNodeLineNo(self, node): 
    514         line_number = -1 
    515         if hasattr(node, 'value'): 
    516             line_number = node.value.lineno 
    517         elif hasattr(node, 'iter'): 
    518             if hasattr(node.iter, 'lineno'): 
    519                 line_number = node.iter.lineno 
    520         return line_number 
    521  
    522     def GetNodeAsString(self, node): 
    523         res = '' 
    524         if hasattr(node, 'n'): 
    525             res = str(node.n) 
    526         elif hasattr(node, 'id'): 
    527             res = node.id 
    528         return res 
    529  
    530     def GetForRange(self, node): 
     501    def get_for_range(self, node): 
    531502        stop = "" 
    532503        start = '0' 
     
    565536                    iterator = self.current_statement 
    566537                    self.current_statement = '' 
    567                     if iterator not in self.C_IntVars: 
    568                         self.C_IntVars.append(iterator) 
    569                     start, stop, step = self.GetForRange(node) 
     538                    if iterator not in self.c_int_vars: 
     539                        self.c_int_vars.append(iterator) 
     540                    start, stop, step = self.get_for_range(node) 
    570541                    self.write_c("for(" + iterator + "=" + str(start) + 
    571542                                 " ; " + iterator + " < " + str(stop) + 
     
    575546                    fForDone = True 
    576547        if not fForDone: 
    577             line_number = self.getNodeLineNo(node) 
     548            # Generate the statement that is causing the error 
    578549            self.current_statement = '' 
    579550            self.write_c('for ') 
     
    582553            self.visit(node.iter) 
    583554            self.write_c(':') 
    584             errStr = "Conversion Error in function " + self.name + ", Line #" + str(line_number) 
    585             errStr += "\nPython for expression not supported: '" + self.current_statement + "'" 
    586             raise Exception(errStr) 
     555            # report the error 
     556            self.unsupported("unsupported " + self.current_statement) 
    587557 
    588558    def visit_While(self, node): 
     
    594564 
    595565    def visit_With(self, node): 
     566        self.unsupported(node) 
    596567        self.newline(node) 
    597568        self.write_python('with ') 
     
    605576    def visit_Pass(self, node): 
    606577        self.newline(node) 
    607         self.write_python('pass') 
     578        #self.write_python('pass') 
    608579 
    609580    def visit_Print(self, node): 
     581        # TODO: print support would be nice, though hard to do 
     582        self.unsupported(node) 
    610583        # CRUFT: python 2.6 only 
    611584        self.newline(node) 
     
    625598 
    626599    def visit_Delete(self, node): 
     600        self.unsupported(node) 
    627601        self.newline(node) 
    628602        self.write_python('del ') 
     
    633607 
    634608    def visit_TryExcept(self, node): 
     609        self.unsupported(node) 
    635610        self.newline(node) 
    636611        self.write_python('try:') 
     
    640615 
    641616    def visit_TryFinally(self, node): 
     617        self.unsupported(node) 
    642618        self.newline(node) 
    643619        self.write_python('try:') 
     
    648624 
    649625    def visit_Global(self, node): 
     626        self.unsupported(node) 
    650627        self.newline(node) 
    651628        self.write_python('global ' + ', '.join(node.names)) 
     
    676653 
    677654    def visit_Raise(self, node): 
     655        self.unsupported(node) 
    678656        # CRUFT: Python 2.6 / 3.0 compatibility 
    679657        self.newline(node) 
     
    697675 
    698676    def visit_Attribute(self, node): 
    699         errStr = "Conversion Error in function " + self.name + ", Line #" + str(node.value.lineno) 
    700         errStr += "\nPython expression not supported: '" + node.value.id + "." + node.attr + "'" 
    701         raise Exception(errStr) 
     677        self.unsupported(node, "attribute reference a.b not supported") 
    702678        self.visit(node.value) 
    703679        self.write_python('.' + node.attr) 
     
    711687                want_comma.append(True) 
    712688        if hasattr(node.func, 'id'): 
    713             if node.func.id not in self.C_Functions: 
    714                 self.C_Functions.append(node.func.id) 
     689            if node.func.id not in self.c_functions: 
     690                self.c_functions.append(node.func.id) 
    715691            if node.func.id == 'abs': 
    716692                self.write_c("fabs ") 
     
    718694                self.write_c('(int) ') 
    719695            elif node.func.id == "SINCOS": 
    720                 self.WriteSincos(node) 
     696                self.write_sincos(node) 
    721697                return 
    722698            else: 
     
    724700        else: 
    725701            self.visit(node.func) 
    726 #self.C_Functions 
    727702        self.write_c('(') 
    728703        for arg in node.args: 
     
    748723    def visit_Name(self, node): 
    749724        self.write_c(node.id) 
    750         if node.id in self.C_Pointers and not self.SubRef: 
     725        if node.id in self.c_pointers and not self.in_subref: 
    751726            self.write_c("[0]") 
    752727        name = "" 
     
    756731        else: 
    757732            name = node.id 
    758         # add variable to C_Vars if it ins't there yet, not an argument and not a number 
    759         if (name not in self.C_Functions and name not in self.C_Vars and 
    760                 name not in self.C_IntVars and name not in self.arguments and 
    761                 name not in self.C_Constants and not name.isdigit()): 
    762             if self.InSubscript: 
    763                 self.C_IntVars.append(node.id) 
     733        # add variable to c_vars if it ins't there yet, not an argument and not a number 
     734        if (name not in self.c_functions and name not in self.c_vars and 
     735                name not in self.c_int_vars and name not in self.arguments and 
     736                name not in self.c_constants and not name.isdigit()): 
     737            if self.in_subscript: 
     738                self.c_int_vars.append(node.id) 
    764739            else: 
    765                 self.C_Vars.append(node.id) 
     740                self.c_vars.append(node.id) 
    766741 
    767742    def visit_Str(self, node): 
     
    777752        for idx, item in enumerate(node.elts): 
    778753            if idx: 
    779                 self.Tuples.append(item) 
     754                self.tuples.append(item) 
    780755            else: 
    781756                self.visit(item) 
     
    793768                    s += str(item.n) 
    794769            if s: 
    795                 self.C_Vectors.append(s) 
    796                 vec_name = "vec"  + str(len(self.C_Vectors)) 
     770                self.c_vectors.append(s) 
     771                vec_name = "vec"  + str(len(self.c_vectors)) 
    797772                self.write_c(vec_name) 
    798773        return visit 
     
    803778 
    804779    def visit_Dict(self, node): 
     780        self.unsupported(node) 
    805781        self.write_python('{') 
    806782        for idx, (key, value) in enumerate(zip(node.keys, node.values)): 
     
    901877 
    902878    def visit_Subscript(self, node): 
    903         if node.value.id not in self.C_Constants: 
    904             if node.value.id not in self.C_Pointers: 
    905                 self.C_Pointers.append(node.value.id) 
    906         self.SubRef = True 
     879        if node.value.id not in self.c_constants: 
     880            if node.value.id not in self.c_pointers: 
     881                self.c_pointers.append(node.value.id) 
     882        self.in_subref = True 
    907883        self.visit(node.value) 
    908         self.SubRef = False 
     884        self.in_subref = False 
    909885        self.write_c('[') 
    910         self.InSubscript = True 
     886        self.in_subscript = True 
    911887        self.visit(node.slice) 
    912         self.InSubscript = False 
     888        self.in_subscript = False 
    913889        self.write_c(']') 
    914890 
     
    931907 
    932908    def visit_Yield(self, node): 
     909        self.unsupported(node) 
    933910        self.write_python('yield ') 
    934911        self.visit(node.value) 
    935912 
    936913    def visit_Lambda(self, node): 
     914        self.unsupported(node) 
    937915        self.write_python('lambda ') 
    938916        self.visit(node.args) 
     
    941919 
    942920    def visit_Ellipsis(self, node): 
     921        self.unsupported(node) 
    943922        self.write_python('Ellipsis') 
    944923 
     
    960939 
    961940    def visit_DictComp(self, node): 
     941        self.unsupported(node) 
    962942        self.write_python('{') 
    963943        self.visit(node.key) 
     
    988968 
    989969    def visit_alias(self, node): 
     970        self.unsupported(node) 
    990971        self.write_python(node.name) 
    991972        if node.asname is not None: 
     
    1005986    def visit_arguments(self, node): 
    1006987        self.signature(node) 
     988 
     989    def unsupported(self, node, message=None): 
     990        if hasattr(node, "value"): 
     991            lineno = node.value.lineno 
     992        elif hasattr(node, "iter"): 
     993            lineno = node.iter.lineno 
     994        else: 
     995            #print(dir(node)) 
     996            lineno = 0 
     997 
     998        lineno += self.lineno_offset 
     999        if self.fname: 
     1000            location = "%s(%d)" % (self.fname, lineno) 
     1001        else: 
     1002            location = "%d" % (self.fname, lineno) 
     1003        if self.current_function: 
     1004            location += ", function %s" % self.current_function 
     1005        if message is None: 
     1006            message = node.__class__.__name__ + " syntax not supported" 
     1007        raise SyntaxError("[%s] %s" % (location, message)) 
    10071008 
    10081009def print_function(f=None): 
     
    10181019        print(tree_source) 
    10191020 
    1020 def translate(functions, constants=0): 
     1021def define_constant(name, value, block_size=1): 
     1022    # type: (str, any, int) -> str 
     1023    """ 
     1024    Convert a python constant into a C constant of the same name. 
     1025 
     1026    Returns the C declaration of the constant as a string, possibly containing 
     1027    line feeds.  The string will not be indented. 
     1028 
     1029    Supports int, double and sequences of double. 
     1030    """ 
     1031    const = "constant "  # OpenCL needs globals to be constant 
     1032    if isinstance(value, int): 
     1033        parts = [const + "int ", name, " = ", "%d"%value, ";"] 
     1034    elif isinstance(value, float): 
     1035        parts = [const + "double ", name, " = ", "%.15g"%value, ";"] 
     1036    else: 
     1037        try: 
     1038            len(value) 
     1039        except TypeError: 
     1040            raise TypeError("constant %s must be int, float or [float, ...]"%name) 
     1041        # extend constant arrays to a multiple of 4; not sure if this 
     1042        # is necessary, but some OpenCL targets broke if the number 
     1043        # of parameters in the parameter table was not a multiple of 4, 
     1044        # so do it for all constant arrays to be safe. 
     1045        if len(value)%block_size != 0: 
     1046            value = list(value) + [0.]*(block_size - len(value)%block_size) 
     1047        elements = ["%.15g"%v for v in value] 
     1048        parts = [const + "double ", name, "[]", " = ", 
     1049                 "{\n   ", ", ".join(elements), "\n};"] 
     1050 
     1051    return "".join(parts) 
     1052 
     1053 
     1054# Modified from the following: 
     1055# 
     1056#    http://code.activestate.com/recipes/578272-topological-sort/ 
     1057#    Copyright (C) 2012 Sam Denton 
     1058#    License: MIT 
     1059def ordered_dag(dag): 
     1060    # type: (Dict[T, Set[T]]) -> Iterator[T] 
     1061    """ 
     1062    Given a dag defined by a dictionary of {k1: [k2, ...]} yield keys 
     1063    in order such that every key occurs after the keys it depends upon. 
     1064 
     1065    This is an iterator not a sequence.  To reverse it use:: 
     1066 
     1067        reversed(tuple(ordered_dag(dag))) 
     1068 
     1069    Raise an error if there are any cycles. 
     1070 
     1071    Keys are arbitrary hashable values. 
     1072    """ 
     1073    # Local import to make the function stand-alone, and easier to borrow 
     1074    from functools import reduce 
     1075 
     1076    dag = dag.copy() 
     1077 
     1078    # make leaves depend on the empty set 
     1079    leaves = reduce(set.union, dag.values()) - set(dag.keys()) 
     1080    dag.update({node: set() for node in leaves}) 
     1081    while True: 
     1082        leaves = set(node for node, links in dag.items() if not links) 
     1083        if not leaves: 
     1084            break 
     1085        for node in leaves: 
     1086            yield node 
     1087        dag = {node: (links-leaves) 
     1088               for node, links in dag.items() if node not in leaves} 
     1089    if dag: 
     1090        raise ValueError("Cyclic dependes exists amongst these items:\n%s" 
     1091                         % ", ".join(str(node) for node in dag.keys())) 
     1092 
     1093def translate(functions, constants=None): 
     1094    # type: (List[(str, str, int)], Dict[str, any]) -> List[str] 
     1095    """ 
     1096    Convert a set of functions 
     1097    """ 
    10211098    snippets = [] 
    10221099    #snippets.append("#include <math.h>") 
    10231100    #snippets.append("") 
    1024     for source, fname, line_no in functions: 
    1025         line_directive = '#line %d "%s"'%(line_no, fname.replace('\\', '\\\\')) 
     1101    for source, fname, lineno in functions: 
     1102        line_directive = '#line %d "%s"'%(lineno, fname.replace('\\', '\\\\')) 
    10261103        snippets.append(line_directive) 
    10271104        tree = ast.parse(source) 
    1028         # in the future add filename, offset, constants 
    1029         c_code = to_source(tree, functions, constants) 
     1105        c_code = to_source(tree, constants=constants, fname=fname, lineno=lineno) 
    10301106        snippets.append(c_code) 
    10311107    return snippets 
     
    10511127    with open(fname_in, "r") as python_file: 
    10521128        code = python_file.read() 
    1053  
    1054     translation = translate([code, fname_in, 1])[0] 
     1129    name = "gauss" 
     1130    code = (code 
     1131            .replace(name+'.n', 'GAUSS_N') 
     1132            .replace(name+'.z', 'GAUSS_Z') 
     1133            .replace(name+'.w', 'GAUSS_W')) 
     1134 
     1135    translation = translate([(code, fname_in, 1)])[0] 
    10551136 
    10561137    with open(fname_out, "w") as file_out: 
Note: See TracChangeset for help on using the changeset viewer.