Changeset 8224d24 in sasmodels


Ignore:
Timestamp:
Dec 18, 2017 1:17:26 PM (7 years ago)
Author:
Paul Kienzle <pkienzle@…>
Children:
1941ec6
Parents:
1ddb794 (diff), ddfdb16 (diff)
Note: this is a merge changeset, the changes displayed below correspond to the merge itself.
Use the (diff) links above to see all the changes relative to each parent.
Message:

Merge branch 'master' of https://github.com/omereis/sasmodels

Location:
sasmodels
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • sasmodels/autoc.py

    r1ddb794 r8224d24  
    9292                    # not special: add function to translate stack 
    9393                    translate.append((name, obj)) 
    94             elif isinstance(obj, float): 
     94            elif isinstance(obj, (int, float, list, tuple, np.ndarray)): 
    9595                constants[name] = obj 
    96                 snippets.append('#line 1 "%s"' % escaped_filename) 
    97                 snippets.append("const double %s = %.15g;"%(name, obj)) 
    98             elif isinstance(obj, int): 
    99                 constants[name] = obj 
    100                 snippets.append('#line 1 "%s"' % escaped_filename) 
    101                 snippets.append("const int %s = %d;"%(name, obj)) 
    102             elif isinstance(obj, (list, tuple, np.ndarray)): 
    103                 constants[name] = obj 
    104                 # extend constant arrays to a multiple of 4; not sure if this 
    105                 # is necessary, but some OpenCL targets broke if the number 
    106                 # of parameters in the parameter table was not a multiple of 4, 
    107                 # so do it for all constant arrays to be safe. 
    108                 if len(obj)%4 != 0: 
    109                     obj = list(obj) + [0.]*(4-len(obj)) 
    110                 vals = ", ".join("%.15g"%v for v in obj) 
    111                 snippets.append('#line 1 "%s"' % escaped_filename) 
    112                 snippets.append("const double %s[] = {%s};" %(name, vals)) 
     96                # Claim all constants are declared on line 1 
     97                snippets.append('#line 1 "%s"'%escaped_filename) 
     98                snippets.append(define_constant(name, obj)) 
    11399            elif isinstance(obj, special.Gauss): 
    114                 constants["GAUSS_N"] = obj.n 
    115                 constants["GAUSS_Z"] = obj.z 
    116                 constants["GAUSS_W"] = obj.w 
     100                #constants["GAUSS_N"] = obj.n 
     101                #constants["GAUSS_Z"] = obj.z 
     102                #constants["GAUSS_W"] = obj.w 
    117103                libs.append('lib/gauss%d.c'%obj.n) 
    118104                source = (source.replace(name+'.n', 'GAUSS_N') 
     
    142128    info.Iq = info.Iqxy = info.form_volume = None 
    143129 
     130def define_constant(name, value): 
     131    if isinstance(value, int): 
     132        parts = ["int ", name, " = ", "%d"%value, ";"] 
     133    elif isinstance(value, float): 
     134        parts = ["double ", name, " = ", "%.15g"%value, ";"] 
     135    else: 
     136        # extend constant arrays to a multiple of 4; not sure if this 
     137        # is necessary, but some OpenCL targets broke if the number 
     138        # of parameters in the parameter table was not a multiple of 4, 
     139        # so do it for all constant arrays to be safe. 
     140        if len(value)%4 != 0: 
     141            value = list(value) + [0.]*(4 - len(value)%4) 
     142        elements = ["%.15g"%v for v in value] 
     143        parts = ["double ", name, "[]", " = ", 
     144                 "{\n   ", ", ".join(elements), "\n};"] 
     145    return "".join(parts) 
     146 
    144147 
    145148# Modified from the following: 
  • sasmodels/py2c.py

    r1ddb794 r8224d24  
    575712/15/2017, OE: Precedence maintained by writing opening and closing 
    5858                parenthesesm '(',')', in procedure 'visit_BinOp'. 
     5912/18/2017, OE: Added call to 'add_current_line()' at the beginning 
     60                of visit_Return 
     61 
    5962""" 
    6063import ast 
     
    99102 
    100103 
    101 #def to_source(node, indent_with=' ' * 4, add_line_information=False): 
    102 def to_source(node, func_name): 
     104def to_source(node, func_name, constants=None): 
    103105    """This function can convert a node tree back into python sourcecode. 
    104106    This is useful for debugging purposes, especially if you're dealing with 
     
    118120    number information of statement nodes. 
    119121    """ 
    120     generator = SourceGenerator(' ' * 4, False) 
    121 #    generator.required_functions = func_name 
     122    generator = SourceGenerator(' ' * 4, False, constants) 
    122123    generator.visit(node) 
    123124 
    124 #    return ''.join(generator.result) 
    125125    return ''.join(generator.c_proc) 
    126126 
     
    138138    """ 
    139139 
    140     def __init__(self, indent_with, add_line_information=False): 
     140    def __init__(self, indent_with, add_line_information=False, constants=None): 
    141141        self.result = [] 
    142142        self.indent_with = indent_with 
     
    160160        self.C_Functions = [] 
    161161        self.C_Vectors = [] 
     162        self.C_Constants = constants 
    162163        self.SubRef = False 
    163164        self.InSubscript = False 
     
    262263                         % arg_name, str(default.n)) 
    263264                self.warnings.append(w_str) 
    264 #                self.write_python('=') 
    265 #                self.visit(default) 
    266265 
    267266    def decorators(self, node): 
     
    321320        self.visit(node.value) 
    322321        self.add_semi_colon() 
    323 #        self.write_c(';') 
    324322        self.add_current_line() 
    325323        for n, item in enumerate(self.Tuples): 
     
    347345        self.visit(node.value) 
    348346        self.add_semi_colon() 
    349 #        self.write_c(';') 
    350347        self.add_current_line() 
    351348 
     
    408405            fLine = True 
    409406            start_var += 1 
    410 #        if(len(self.C_IntVars) > 0): 
    411 #            s = self.listToDeclare(self.C_IntVars) 
    412 #            self.c_proc.insert(start_var, "    int " + s + ";\n") 
    413 #            fLine = True 
    414 #            start_var += 1 
    415407        if(len(self.C_Vectors) > 0): 
    416408            s = self.listToDeclare(self.C_Vectors) 
     
    446438 
    447439    def getMethodSignature(self): 
    448 #        args_str = ListToString(self.arguments) 
    449440        args_str = '' 
    450441        for n in range(len(self.arguments)): 
     
    453444                args_str += ", " 
    454445        return(args_str) 
    455 #        self.strMethodSignature = 'double ' + self.name + '(' + args_str + ")" 
    456446 
    457447    def InsertSignature(self): 
     
    473463        self.arguments = [] 
    474464        self.name = node.name 
    475 #        if self.name not in self.required_functions[0]: 
    476 #           return 
    477465        print("Parsing '" + self.name + "'") 
    478466        args_str = "" 
     
    480468        self.visit(node.args) 
    481469        self.getMethodSignature() 
    482 # for C 
    483470        self.signature_line = len(self.c_proc) 
    484 #        self.add_c_line(self.strMethodSignature) 
    485471        self.add_c_line("\n{") 
    486472        start_vars = len(self.c_proc) + 1 
     
    557543        if(hasattr(node,'value')): 
    558544            line_number = node.value.lineno 
    559         elif hasattr(node,'iter'): 
    560             if hasattr(node.iter,'lineno'): 
     545        elif hasattr(node, 'iter'): 
     546            if hasattr(node.iter, 'lineno'): 
    561547                line_number = node.iter.lineno 
    562548        return(line_number) 
     
    599585# Iterator name is in node.target.id. 
    600586        self.add_current_line() 
    601 #        if(len(self.current_statement) > 0): 
    602 #            self.add_c_line(self.current_statement) 
    603 #            self.current_statement = '' 
    604587        fForDone = False 
    605588        self.current_statement = '' 
     
    701684 
    702685    def visit_Return(self, node): 
    703         self.newline(node) 
     686        self.add_current_line() 
    704687        if node.value is None: 
    705688            self.write_c('return') 
     
    802785            name = node.id 
    803786#       add variable to C_Vars if it ins't there yet, not an argument and not a number 
    804         if((name not in self.C_Functions) and (name not in self.C_Vars) and \ 
     787        if ((name not in self.C_Functions) and (name not in self.C_Vars) and \ 
    805788            (name not in self.C_IntVars) and (name not in self.arguments) and \ 
    806             (name.isdigit() == False)): 
     789            (name not in self.C_Constants) and (name.isdigit() == False)): 
    807790            if(self.InSubscript): 
    808791                self.C_IntVars.append(node.id) 
     
    947930 
    948931    def visit_Subscript(self, node): 
    949         if(node.value.id not in self.C_Pointers): 
    950             self.C_Pointers.append(node.value.id) 
     932        if (node.value.id not in self.C_Constants): 
     933            if(node.value.id not in self.C_Pointers): 
     934                self.C_Pointers.append(node.value.id) 
    951935        self.SubRef = True 
    952936        self.visit(node.value) 
     
    10481032                self.visit(if_) 
    10491033 
    1050 #    def visit_excepthandler(self, node): 
    1051 #        self.newline(node) 
    1052 #        self.write_python('except') 
    1053 #        if node.type is not None: 
    1054 #            self.write_python(' ') 
    1055 #            self.visit(node.type) 
    1056 #            if node.name is not None: 
    1057 #                self.write_python(' as ') 
    1058 #                self.visit(node.name) 
    1059 #        self.body(node.body) 
    1060  
    10611034    def visit_arguments(self, node): 
    10621035        self.signature(node) 
    1063  
    1064 def Iq1(q, porod_scale, porod_exp, lorentz_scale, lorentz_length, peak_pos, lorentz_exp=17): 
    1065     z1 = z2 = z = abs(q - peak_pos) * lorentz_length 
    1066     if(q > p): 
    1067         q = p + 17 
    1068         p = q - 5 
    1069     z3 = -8 
    1070     inten = (porod_scale / q ** porod_exp 
    1071                 + lorentz_scale /(1 + z ** lorentz_exp)) 
    1072     return inten 
    1073  
    1074 def Iq(q, porod_scale, porod_exp, lorentz_scale, lorentz_length, peak_pos, lorentz_exp=17): 
    1075     z1 = z2 = z = abs(q - peak_pos) * lorentz_length 
    1076     if(q > p): 
    1077         q = p + 17 
    1078         p = q - 5 
    1079     elif(q == p): 
    1080         q = p * q 
    1081         q *= z1 
    1082         p = z1 
    1083     elif(q == 17): 
    1084         q = p * q - 17 
    1085     else: 
    1086         q += 7 
    1087     z3 = -8 
    1088     inten = (porod_scale / q ** porod_exp 
    1089                 + lorentz_scale /(1 + z ** lorentz_exp)) 
    1090     return inten 
    10911036 
    10921037def print_function(f=None): 
     
    11041049def translate(functions, constants=0): 
    11051050    snippets = [] 
     1051    #snippets.append("#include <math.h>") 
     1052    #snippets.append("") 
    11061053    for source, fname, line_no in functions: 
    1107         line_directive = '#line %d "%s"'%(line_no, fname.replace('\\','\\\\')) 
     1054        line_directive = '#line %d "%s"'%(line_no, fname.replace('\\', '\\\\')) 
    11081055        snippets.append(line_directive) 
    11091056        tree = ast.parse(source) 
    11101057        # in the future add filename, offset, constants 
    1111         c_code = to_source(tree, functions) 
     1058        c_code = to_source(tree, functions, constants) 
    11121059        snippets.append(c_code) 
    11131060    return snippets 
  • sasmodels/kerneldll.py

    r2d81cfe r1ddb794  
    185185        subprocess.check_output(command, shell=shell, stderr=subprocess.STDOUT) 
    186186    except subprocess.CalledProcessError as exc: 
    187         raise RuntimeError("compile failed.\n%s\n%s"%(command_str, exc.output)) 
     187        raise RuntimeError("compile failed.\n%s\n%s" 
     188                           % (command_str, exc.output.decode())) 
    188189    if not os.path.exists(output): 
    189190        raise RuntimeError("compile failed.  File is in %r"%source) 
Note: See TracChangeset for help on using the changeset viewer.