source: sasmodels/sasmodels/autoc.py @ 67cc0ff

Last change on this file since 67cc0ff was 67cc0ff, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

update to new sasmodels api

  • Property mode set to 100644
File size: 6.3 KB
Line 
1"""
2Automatically translate python models to C
3"""
4from __future__ import print_function
5
6import ast
7import inspect
8from functools import reduce
9
10import numpy as np
11
12from . import py2c
13from . import special
14
15# pylint: disable=unused-import
16try:
17    from types import ModuleType
18    #from .modelinfo import ModelInfo  # circular import
19except ImportError:
20    pass
21# pylint: enable=unused-import
22
# C library files needed by functions from the `special` module.  When
# convert() finds that a translated model references one of these names
# (and the name exists in `special`), the listed files are appended to the
# model's source list in the order given here.
DEPENDENCY = dict(
    core_shell_kernel=['lib/core_shell.c'],
    fractal_sq=['lib/fractal_sq.c'],
    gfn4=['lib/gfn.c'],
    polevl=['lib/polevl.c'],
    p1evl=['lib/polevl.c'],
    sas_2J1x_x=['lib/polevl.c', 'lib/sas_J1.c'],
    sas_3j1x_x=['lib/sas_3j1x_x.c'],
    sas_erf=['lib/polevl.c', 'lib/sas_erf.c'],
    sas_erfc=['lib/polevl.c', 'lib/sas_erf.c'],
    sas_gamma=['lib/sas_gamma.c'],
    sas_J0=['lib/polevl.c', 'lib/sas_J0.c'],
    sas_J1=['lib/polevl.c', 'lib/sas_J1.c'],
    sas_JN=['lib/polevl.c', 'lib/sas_J0.c', 'lib/sas_J1.c', 'lib/sas_JN.c'],
    sas_Si=['lib/Si.c'],
)

# Global names that convert() skips entirely: no module lookup and no
# constant definition is generated for them.  Presumably they are provided
# as macros by the C kernel preamble -- NOTE(review): confirm.
DEFINES = frozenset((
    "M_PI", "M_PI_2", "M_PI_4", "M_SQRT1_2", "M_E",
    "NAN", "INFINITY", "M_PI_180", "M_4PI_3",
))
41
def convert(info, module):
    # type: ("ModelInfo", ModuleType) -> None
    """
    Translate the python kernel functions of a model to C.

    The public model functions Iq, Iqac, Iqabc, Iqxy and form_volume that
    are callable on *info* are translated.  Every global in *module* that
    they reference is translated with them:
    - python helper functions are translated recursively;
    - functions from the `special` module pull in their C library files
      (see DEPENDENCY);
    - int/float/list/tuple/ndarray values become C constant declarations;
    - a `special.Gauss` object becomes the GAUSS_N, GAUSS_Z and GAUSS_W
      constants.

    On success *info.source* is set to the list of required C library
    files and *info.c_code* to the generated C code.  The python functions
    on *info* are then set to None.

    *info* is left untouched in three cases:
    - it already has C code (truthy *info.source*, or *info.c_code* is not
      None);
    - any public function is marked as vectorized;
    - none of the public functions is callable.

    Raises TypeError if a referenced global has a type that cannot be
    converted to C.
    """
    from collections import deque

    # Check if there is already C code
    if info.source or info.c_code is not None:
        return

    public_methods = "Iq", "Iqac", "Iqabc", "Iqxy", "form_volume"

    tagged = set()  # type: Set[str]
    translate = deque()  # type: Deque[Tuple[str, Callable]]
    for function_name in public_methods:
        function = getattr(info, function_name)
        if callable(function):
            if getattr(function, 'vectorized', None):
                return  # Don't try to translate vectorized code
            tagged.add(function_name)
            translate.append((function_name, function))
    if not translate:
        # nothing to translate---maybe Iq, etc. are already C snippets?
        return

    libs = []  # type: List[str]
    snippets = []  # type: List[str]
    constants = {}  # type: Dict[str, Any]
    code = {}  # type: Dict[str, Tuple[str, str, int]]
    depends = {}  # type: Dict[str, Set[str]]
    # Globals bound to special.Gauss objects.  Their attribute accesses must
    # be rewritten in *every* function that uses them, not only in the first
    # function that caused them to be tagged.
    gauss_names = set()  # type: Set[str]
    while translate:
        function_name, function = translate.popleft()
        filename = function.__code__.co_filename
        escaped_filename = filename.replace('\\', '\\\\')
        offset = function.__code__.co_firstlineno
        refs = function.__code__.co_names
        depends[function_name] = set(refs)
        source = inspect.getsource(function)
        for name in refs:
            # Each global is resolved only once, on first reference.
            if name in tagged or name in DEFINES:
                continue
            tagged.add(name)
            obj = getattr(module, name, None)
            if obj is None:
                pass # ignore unbound variables for now
                #raise ValueError("global %s is not defined" % name)
            elif callable(obj):
                if getattr(special, name, None):
                    # special symbol: look up dependencies
                    libs.extend(DEPENDENCY.get(name, []))
                else:
                    # not special: add function to translate stack
                    translate.append((name, obj))
            elif isinstance(obj, (int, float, list, tuple, np.ndarray)):
                constants[name] = obj
                # Claim all constants are declared on line 1
                snippets.append('#line 1 "%s"'%escaped_filename)
                snippets.append(define_constant(name, obj))
            elif isinstance(obj, special.Gauss):
                # NOTE(review): all Gauss objects share the GAUSS_* names, so
                # a model using two different Gauss objects would emit
                # duplicate C definitions.
                for var, value in zip(("N", "Z", "W"), (obj.n, obj.z, obj.w)):
                    var = "GAUSS_"+var
                    constants[var] = value
                    snippets.append('#line 1 "%s"'%escaped_filename)
                    snippets.append(define_constant(var, value))
                #libs.append('lib/gauss%d.c'%obj.n)
                gauss_names.add(name)
            else:
                raise TypeError("Could not convert global %s of type %s"
                                % (name, str(type(obj))))

        # Replace python attribute access on Gauss objects with the C
        # constants, including objects first seen in an earlier function.
        for name in gauss_names.intersection(refs):
            source = (source.replace(name+'.n', 'GAUSS_N')
                      .replace(name+'.z', 'GAUSS_Z')
                      .replace(name+'.w', 'GAUSS_W'))

        # add (possibly modified) source to set of functions to compile
        code[function_name] = (source, filename, offset)

    # remove duplicates from the dependency list, keeping first-seen order
    unique_libs = []  # type: List[str]
    seen_libs = set()  # type: Set[str]
    for lib in libs:
        if lib not in seen_libs:
            seen_libs.add(lib)
            unique_libs.append(lib)

    # translate source, callees before callers
    ordered_code = [code[name] for name in ordered_dag(depends) if name in code]
    functions = py2c.translate(ordered_code, constants)
    snippets.extend(functions)

    # update model info
    info.source = unique_libs
    info.c_code = "\n".join(snippets)
    info.Iq = info.Iqac = info.Iqabc = info.Iqxy = info.form_volume = None
131
def define_constant(name, value):
    # type: (str, Any) -> str
    """
    Return a C declaration for the python constant *name* = *value*.

    Integers (including numpy integers and bool) become ``int``.  Other real
    scalars become ``double``.  Any other value is treated as a sequence of
    numbers and becomes a ``double name[]`` array.

    Doubles are written with 17 significant digits, the minimum that
    guarantees the C literal parses back to exactly the same IEEE double.
    The 15 digits used previously (DBL_DIG) silently truncated values such
    as pi or quadrature weights.
    """
    import numbers

    if isinstance(value, numbers.Integral):
        return "int %s = %d;" % (name, value)
    if isinstance(value, numbers.Real):
        return "double %s = %.17g;" % (name, value)

    # Extend constant arrays to a multiple of 4.  Some OpenCL targets broke
    # if the number of parameters in the parameter table was not a multiple
    # of 4, so do it for all constant arrays to be safe.  An empty array is
    # padded as well, since an empty initializer `{}` is not valid C.
    elements = list(value)
    if not elements or len(elements) % 4:
        elements += [0.] * (4 - len(elements) % 4)
    body = ", ".join("%.17g" % v for v in elements)
    return "double %s[] = {\n   %s\n};" % (name, body)
148
149
150# Modified from the following:
151#
152#    http://code.activestate.com/recipes/578272-topological-sort/
153#    Copyright (C) 2012 Sam Denton
154#    License: MIT
def ordered_dag(dag):
    # type: (Dict[T, Iterable[T]]) -> Iterator[T]
    """
    Yield the nodes of *dag* so that each node follows all of its dependencies.

    *dag* maps each node to the nodes it depends on.  Nodes that appear only
    as dependencies are treated as having no dependencies of their own.
    Nodes that become ready at the same step are yielded sorted by
    ``str(node)``, so the order is reproducible between runs regardless of
    hash randomization.  *dag* and its values are not modified.

    Raises ValueError, after yielding every node that can be ordered, if
    the remaining nodes contain a dependency cycle.
    """
    # Copy into fresh sets so the caller's dict and values are untouched.
    remaining = dict((node, set(links)) for node, links in dag.items())

    # make leaves depend on the empty set (works for an empty dag too)
    leaves = set().union(*remaining.values()) - set(remaining)
    remaining.update((node, set()) for node in leaves)
    while True:
        ready = set(node for node, links in remaining.items() if not links)
        if not ready:
            break
        for node in sorted(ready, key=str):
            yield node
        remaining = dict((node, links - ready)
                         for node, links in remaining.items()
                         if node not in ready)
    if remaining:
        raise ValueError("Cyclic dependencies exist amongst these items:\n%s"
                         % ", ".join(str(node)
                                     for node in sorted(remaining, key=str)))
Note: See TracBrowser for help on using the repository browser.