1 | """ |
---|
2 | Automatically translate python models to C |
---|
3 | """ |
---|
4 | from __future__ import print_function |
---|
5 | |
---|
6 | import ast |
---|
7 | import inspect |
---|
8 | from functools import reduce |
---|
9 | |
---|
10 | import numpy as np |
---|
11 | |
---|
12 | from . import py2c |
---|
13 | from . import special |
---|
14 | |
---|
15 | # pylint: disable=unused-import |
---|
16 | try: |
---|
17 | from types import ModuleType |
---|
18 | #from .modelinfo import ModelInfo # circular import |
---|
19 | except ImportError: |
---|
20 | pass |
---|
21 | # pylint: enable=unused-import |
---|
22 | |
---|
# Special functions -> C library files that convert() adds to the model's
# source list whenever translated code calls that function.  Files are
# listed in the order they are appended (helpers such as polevl first).
DEPENDENCY = dict(
    core_shell_kernel=['lib/core_shell.c'],
    fractal_sq=['lib/fractal_sq.c'],
    gfn4=['lib/gfn.c'],
    polevl=['lib/polevl.c'],
    p1evl=['lib/polevl.c'],
    sas_2J1x_x=['lib/polevl.c', 'lib/sas_J1.c'],
    sas_3j1x_x=['lib/sas_3j1x_x.c'],
    sas_erf=['lib/polevl.c', 'lib/sas_erf.c'],
    sas_erfc=['lib/polevl.c', 'lib/sas_erf.c'],
    sas_gamma=['lib/sas_gamma.c'],
    sas_J0=['lib/polevl.c', 'lib/sas_J0.c'],
    sas_J1=['lib/polevl.c', 'lib/sas_J1.c'],
    sas_JN=['lib/polevl.c', 'lib/sas_J0.c', 'lib/sas_J1.c', 'lib/sas_JN.c'],
    sas_Si=['lib/Si.c'],
)

# Global names that convert() skips instead of looking them up in the model
# module.  NOTE(review): presumably these are provided as C macros by the
# kernel preamble -- confirm against the C side.
DEFINES = frozenset((
    "M_PI", "M_PI_2", "M_PI_4", "M_SQRT1_2", "M_E",
    "NAN", "INFINITY", "M_PI_180", "M_4PI_3",
))
---|
41 | |
---|
def convert(info, module, dump_path=None):
    # type: ("ModelInfo", ModuleType, Optional[str]) -> None
    """
    Convert Iq, Iqxy and form_volume to C.

    The Python implementations of the public model functions are translated
    with ``py2c.translate``. Every module-level function they reference is
    translated with them, recursively. Referenced numeric constants and
    constant arrays become C ``const`` declarations. Referenced special
    functions pull in their C library files via ``DEPENDENCY``.

    On success *info* is updated in place:

    * ``info.source`` gets the de-duplicated list of C library files.
    * ``info.c_code`` gets the generated C text.
    * ``info.Iq``, ``info.Iqxy`` and ``info.form_volume`` are set to None.

    Nothing is changed in these cases:

    * The model already has C code (``info.source`` is non-empty or
      ``info.c_code`` is not None).
    * None of the public functions is callable.
    * Any of them is marked ``vectorized``.

    *module* is the model module in which referenced global names are
    looked up.

    *dump_path*, if given, names a file to which the generated C text is
    appended. This is a debugging aid. If the file cannot be written, a
    warning is issued and the conversion still completes.

    Raises TypeError if a referenced global has a type that cannot be
    expressed in C.
    """
    import re

    # Models that already carry C code need no translation.
    if info.source or info.c_code is not None:
        return

    public_methods = "Iq", "Iqxy", "form_volume"

    # Names already queued or otherwise handled, so that each global is
    # processed only once even when several functions reference it.
    tagged = []  # type: List[str]
    translate = []  # type: List[Tuple[str, Callable]]
    for function_name in public_methods:
        function = getattr(info, function_name)
        if callable(function):
            if getattr(function, 'vectorized', None):
                return  # Don't try to translate vectorized code
            tagged.append(function_name)
            translate.append((function_name, function))
    if not translate:
        # nothing to translate---maybe Iq, etc. are already C snippets?
        return

    libs = []  # type: List[str]
    snippets = []  # type: List[str]
    constants = {}  # type: Dict[str, Any]
    code = {}  # type: Dict[str, Tuple[str, str, int]]
    depends = {}  # type: Dict[str, Set[str]]
    while translate:
        function_name, function = translate.pop(0)
        filename = function.__code__.co_filename
        offset = function.__code__.co_firstlineno
        refs = function.__code__.co_names
        depends[function_name] = set(refs)
        source = inspect.getsource(function)
        for name in refs:
            if name in tagged or name in DEFINES:
                continue
            tagged.append(name)
            obj = getattr(module, name, None)
            if obj is None:
                pass  # ignore unbound variables for now
            elif callable(obj):
                if getattr(special, name, None):
                    # special symbol: look up its C library dependencies
                    libs.extend(DEPENDENCY.get(name, []))
                else:
                    # not special: queue the function for translation
                    translate.append((name, obj))
            elif isinstance(obj, (float, np.floating)):
                constants[name] = obj
                snippets.append("const double %s = %.15g;" % (name, obj))
            elif isinstance(obj, (int, np.integer)):
                constants[name] = obj
                snippets.append("const int %s = %d;" % (name, obj))
            elif isinstance(obj, (list, tuple, np.ndarray)):
                constants[name] = obj
                # Extend constant arrays with zeros to a multiple of 4.
                # Some OpenCL targets broke when the parameter table length
                # was not a multiple of 4, so do it for all constant arrays
                # to be safe. -len % 4 is the number of zeros needed
                # (0 when the length is already a multiple of 4).
                values = list(obj)
                values.extend([0.] * (-len(values) % 4))
                vals = ", ".join("%.15g" % v for v in values)
                snippets.append("const double %s[] = {%s};" % (name, vals))
            elif isinstance(obj, special.Gauss):
                constants["GAUSS_N"] = obj.n
                constants["GAUSS_Z"] = obj.z
                constants["GAUSS_W"] = obj.w
                libs.append('lib/gauss%d.c' % obj.n)
                # Rewrite only whole "<name>.n/.z/.w" references. A plain
                # str.replace would also hit "my_<name>.n" or "<name>.norm".
                for attr, define in (("n", "GAUSS_N"),
                                     ("z", "GAUSS_Z"),
                                     ("w", "GAUSS_W")):
                    pattern = r"(?<![\w.])%s\.%s(?!\w)" % (re.escape(name), attr)
                    source = re.sub(pattern, define, source)
            else:
                raise TypeError("Could not convert global %s of type %s"
                                % (name, str(type(obj))))

        # add (possibly modified) source to set of functions to compile
        code[function_name] = (source, filename, offset)

    # remove duplicates from the dependency list, keeping first occurrence
    unique_libs = []  # type: List[str]
    for lib in libs:
        if lib not in unique_libs:
            unique_libs.append(lib)

    # Translate in dependency order, so that each function comes after the
    # functions it references.
    ordered_code = [code[name] for name in ordered_dag(depends) if name in code]
    functions = py2c.translate(ordered_code, constants)
    # Keep the const declarations collected above. The translated functions
    # refer to those globals by name.
    snippets.append(functions)
    c_code = "\n".join(snippets)

    if dump_path is not None:
        # Optional debugging aid: keep a copy of the generated C text.
        try:
            with open(dump_path, "a") as fid:
                fid.write(c_code)
        except (IOError, OSError) as exc:
            import warnings
            warnings.warn("could not write %s: %s" % (dump_path, exc))

    # update model info
    info.source = unique_libs
    info.c_code = c_code
    info.Iq = info.Iqxy = info.form_volume = None
---|
153 | |
---|
154 | |
---|
# Modified from the following:
#
# http://code.activestate.com/recipes/578272-topological-sort/
# Copyright (C) 2012 Sam Denton
# License: MIT
def ordered_dag(dag):
    # type: (Dict[T, Iterable[T]]) -> Iterator[T]
    """
    Yield the nodes of *dag* in dependency order.

    *dag* maps each node to the collection of nodes it depends on. Each
    node is yielded only after all of its dependencies. Nodes that appear
    only as dependencies (never as keys) are treated as having none. An
    empty *dag* yields nothing.

    Within each level, nodes keep the iteration order of *dag*.
    Dependency-only nodes follow in ``str``-sorted order. This makes the
    output reproducible from run to run. *dag* itself is not modified.

    Raises ValueError if the remaining nodes contain a cycle. Because this
    is a generator, the error is raised when iteration reaches that point.
    """
    # Copy every link set so that the caller's graph is never mutated; this
    # also accepts lists/tuples/frozensets as link collections.
    dag = {node: set(links) for node, links in dag.items()}

    # Make leaves depend on the empty set. set().union(*...) is safe for an
    # empty graph, unlike reduce() with no initial value. Sort by str since
    # nodes need not be mutually comparable.
    referenced = set().union(*dag.values())
    for node in sorted(referenced - set(dag), key=str):
        dag[node] = set()

    while dag:
        ready = [node for node, links in dag.items() if not links]
        if not ready:
            raise ValueError("Cyclic dependencies exist amongst these items:\n%s"
                             % ", ".join(str(node) for node in dag))
        for node in ready:
            yield node
        done = set(ready)
        dag = {node: links - done
               for node, links in dag.items() if node not in done}
---|