source: sasmodels/sasmodels/autoc.py @ 50b5464

Last change on this file since 50b5464 was 50b5464, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

rearrange autoc to allow py2c translation to do whole program analysis

  • Property mode set to 100644
File size: 6.0 KB
Line 
1"""
2Automatically translate python models to C
3"""
4from __future__ import print_function
5
6import ast
7import inspect
8from functools import reduce
9
10import numpy as np
11
12from . import codegen
13from . import special
14
15# pylint: disable=unused-import
16try:
17    from types import ModuleType
18    #from .modelinfo import ModelInfo  # circular import
19except ImportError:
20    pass
21# pylint: enable=unused-import
22
# C library files required by each function of the `special` module.
# When a translated model calls one of these functions, convert() appends the
# listed files, in this order, to the model's C sources.
DEPENDENCY = dict(
    core_shell_kernel=['lib/core_shell.c'],
    fractal_sq=['lib/fractal_sq.c'],
    gfn4=['lib/gfn.c'],
    polevl=['lib/polevl.c'],
    p1evl=['lib/polevl.c'],
    sas_2J1x_x=['lib/polevl.c', 'lib/sas_J1.c'],
    sas_3j1x_x=['lib/sas_3j1x_x.c'],
    sas_erf=['lib/polevl.c', 'lib/sas_erf.c'],
    sas_erfc=['lib/polevl.c', 'lib/sas_erf.c'],
    sas_gamma=['lib/sas_gamma.c'],
    sas_J0=['lib/polevl.c', 'lib/sas_J0.c'],
    sas_J1=['lib/polevl.c', 'lib/sas_J1.c'],
    sas_JN=['lib/polevl.c', 'lib/sas_J0.c', 'lib/sas_J1.c', 'lib/sas_JN.c'],
    sas_Si=['lib/Si.c'],
)

# Global names that convert() never tries to resolve as model globals.
# Presumably these are provided as macros by the C kernel environment; verify
# against the kernel headers before adding to this set.
DEFINES = frozenset([
    "M_PI", "M_PI_2", "M_PI_4", "M_SQRT1_2", "M_E",
    "NAN", "INFINITY", "M_PI_180", "M_4PI_3",
])
41
def convert(info, module):
    # type: ("ModelInfo", ModuleType) -> None
    """
    Convert the python Iq, Iqxy and form_volume of a model to C.

    *info* is the model info object.  Its Iq, Iqxy and form_volume attributes
    are the functions to translate.  Any callable among them is translated,
    together with every helper function it references in *module*.  Global
    names referenced by those functions are looked up in *module*:

    * callables from the `special` module add their C libraries (DEPENDENCY);
    * other callables are queued for translation;
    * float/int scalars and lists/tuples/arrays become C constants;
    * `special.Gauss` objects become GAUSS_N/GAUSS_Z/GAUSS_W plus the
      matching lib/gauss<n>.c file.

    Returns None without touching *info* when the model already has C code,
    when none of the public methods is callable, or when any of them is
    marked as vectorized.

    Raises TypeError if a referenced global has a type that cannot be
    written as a C constant.  Translation is still a work in progress:
    after updating info.source/info.c_code (and clearing the python
    functions) this prints the result and raises RuntimeError.
    """
    import re

    # Check if there is already C code
    if info.source or info.c_code is not None:
        return

    public_methods = "Iq", "Iqxy", "form_volume"

    # Names already queued or handled; only used for membership tests.
    tagged = set()  # type: Set[str]
    # Functions waiting to be translated, processed first-in first-out.
    translate = []  # type: List[Tuple[str, Callable]]
    for function_name in public_methods:
        function = getattr(info, function_name)
        if callable(function):
            if getattr(function, 'vectorized', None):
                return  # Don't try to translate vectorized code
            tagged.add(function_name)
            translate.append((function_name, function))
    if not translate:
        # nothing to translate---maybe Iq, etc. are already C snippets?
        return

    libs = []  # type: List[str]
    snippets = []  # type: List[str]
    constants = {}  # type: Dict[str, Any]
    code = {}  # type: Dict[str, Tuple[str, str, int]]
    depends = {}  # type: Dict[str, Set[str]]
    while translate:
        function_name, function = translate.pop(0)
        filename = function.__code__.co_filename
        offset = function.__code__.co_firstlineno
        # co_names holds global names *and* attribute names used in the body;
        # names that the module does not define are skipped below.
        refs = function.__code__.co_names
        depends[function_name] = set(refs)
        source = inspect.getsource(function)
        for name in refs:
            if name in DEFINES:
                continue
            obj = getattr(module, name, None)
            if isinstance(obj, special.Gauss):
                # The source rewrite must happen in *every* function that uses
                # the quadrature object, not only the first one to mention it.
                # Word boundaries leave e.g. "my_gauss.n" or "gauss.nodes" alone.
                pattern = r'\b%s\.([nzw])\b' % re.escape(name)
                source = re.sub(pattern,
                                lambda match: 'GAUSS_' + match.group(1).upper(),
                                source)
            if name in tagged:
                continue
            tagged.add(name)
            if obj is None:
                pass  # ignore unbound variables for now
                #raise ValueError("global %s is not defined" % name)
            elif callable(obj):
                if getattr(special, name, None):
                    # special symbol: look up its C library dependencies
                    libs.extend(DEPENDENCY.get(name, []))
                else:
                    # not special: add function to translate stack
                    translate.append((name, obj))
            elif isinstance(obj, (float, np.floating)):
                constants[name] = obj
                snippets.append("const double %s = %.15g;"%(name, obj))
            elif isinstance(obj, (int, np.integer)):
                constants[name] = obj
                snippets.append("const int %s = %d;"%(name, obj))
            elif isinstance(obj, (list, tuple, np.ndarray)):
                constants[name] = obj
                # extend constant arrays to a multiple of 4; not sure if this
                # is necessary, but some OpenCL targets broke if the number
                # of parameters in the parameter table was not a multiple of 4,
                # so do it for all constant arrays to be safe.
                padding = -len(obj) % 4
                if padding:
                    obj = list(obj) + [0.]*padding
                vals = ", ".join("%.15g"%v for v in obj)
                snippets.append("const double %s[] = {%s};" %(name, vals))
            elif isinstance(obj, special.Gauss):
                # source was already rewritten above; register the tables once
                constants["GAUSS_N"] = obj.n
                constants["GAUSS_Z"] = obj.z
                constants["GAUSS_W"] = obj.w
                libs.append('lib/gauss%d.c'%obj.n)
            else:
                raise TypeError("Could not convert global %s of type %s"
                                % (name, str(type(obj))))

        # add (possibly modified) source to set of functions to compile
        code[function_name] = (source, filename, offset)

    # remove duplicates from the dependency list, keeping first-seen order
    unique_libs = []  # type: List[str]
    for lib in libs:
        if lib not in unique_libs:
            unique_libs.append(lib)

    # translate source, callees before callers
    functions = codegen.translate(
        [code[name] for name in ordered_dag(depends) if name in code],
        constants)

    # update model info
    info.source = unique_libs
    # Put the translated code on its own line after the constants so that its
    # first line is not glued to the last snippet (this matters if it starts
    # with a preprocessor directive).
    info.c_code = "\n".join(snippets + [functions])
    info.Iq = info.Iqxy = info.form_volume = None

    # TODO: translation is not finished yet; dump the result and bail out.
    print("source", info.source)
    print(info.c_code)

    raise RuntimeError("not yet converted...")
144
145
146# Modified from the following:
147#
148#    http://code.activestate.com/recipes/578272-topological-sort/
149#    Copyright (C) 2012 Sam Denton
150#    License: MIT
def ordered_dag(dag):
    # type: (Dict[T, Iterable[T]]) -> Iterator[T]
    """
    Yield the nodes of *dag* so that every node follows its dependencies.

    *dag* maps each node to the nodes it depends on; any iterable of
    dependencies is accepted.  Nodes that appear only as dependencies are
    yielded too.  Nodes whose dependencies are all satisfied are yielded
    together in one layer, sorted by their string form so that the output
    is reproducible from run to run.  The input mapping is not modified.

    Raises ValueError, after yielding every node that could be ordered, if
    the remaining nodes form a dependency cycle.
    """
    # Copy into fresh sets: normalizes list/tuple dependencies and keeps the
    # caller's mapping untouched.
    remaining = {node: set(links) for node, links in dag.items()}

    # make leaves depend on the empty set; set().union(*...) also handles an
    # empty dag, where reduce(set.union, ...) would raise TypeError.
    leaves = set().union(*remaining.values()) - set(remaining)
    remaining.update((node, set()) for node in leaves)

    while remaining:
        leaves = set(node for node, links in remaining.items() if not links)
        if not leaves:
            raise ValueError("Cyclic dependencies exist amongst these items:\n%s"
                             % ", ".join(sorted(str(node) for node in remaining)))
        for node in sorted(leaves, key=str):
            yield node
        remaining = {node: (links - leaves)
                     for node, links in remaining.items() if node not in leaves}
Note: See TracBrowser for help on using the repository browser.