source: sasview/sansmodels/src/sans/models/c_extensions/WrapperGenerator.py @ e71440c

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since e71440c was ae3ce4e, checked in by Mathieu Doucet <doucetm@…>, 17 years ago

Moving sansmodels to trunk

  • Property mode set to 100644
File size: 11.6 KB
Line 
1#!/usr/bin/env python
2""" WrapperGenerator class to generate model code automatically.
3"""
4
5import os, sys
6
7class WrapperGenerator:
8    """ Python wrapper generator for C models
9   
10        The developer must provide a header file describing
11        the new model.
12       
13        To provide the name of the Python class to be
14        generated, the .h file must contain the following
15        string in the comments:
16       
17        // [PYTHONCLASS] = my_model
18       
19        where my_model must be replaced by the name of the
20        class that you want to import from sans.models.
21        (example: [PYTHONCLASS] = MyModel
22          will create a class MyModel in sans.models.MyModel.
23          It will also create a class CMyModel in
24          sans_extension.c_models.)
25         
26        Also in comments, each parameter of the params
27        dictionary must be declared with a default value
28        in the following way:
29       
30        // [DEFAULT]=param_name=default_value
31       
32        (example:
33            //  [DEFAULT]=radius=20.0
34        )
35         
36        See cylinder.h for an example.
37       
38       
39        A .c file corresponding to the .h file should also
40        be provided (example: my_model.h, my_model.c).
41   
42        The .h file should define two function definitions. For example,
43        cylinder.h defines the following:
44       
45            /// 1D scattering function
46            double cylinder_analytical_1D(CylinderParameters *pars, double q);
47           
48            /// 2D scattering function
49            double cylinder_analytical_2D(CylinderParameters *pars, double q, double phi);
50           
51        The .c file implements those functions.
52       
53        @author: Mathieu Doucet / UTK
54        @contact: mathieu.doucet@nist.gov
55    """
56   
57    def __init__(self, filename):
58        """ Initialization """
59       
60        ## Name of .h file to generate wrapper from
61        self.file = filename
62       
63        # Info read from file
64       
65        ## Name of python class to write
66        self.pythonClass = None
67        ## Parser in struct section
68        self.inStruct = False
69        ## Name of struct for the c object
70        self.structName = None
71        ## Dictionary of parameters
72        self.params = {}
73        ## ModelCalculation module flag
74        self.modelCalcFlag = False
75        ## List of default parameters (text)
76        self.default_list = ""
77        ## Dictionary of units
78        self.details = ""
79       
80    def __repr__(self):
81        """ Simple output for printing """
82       
83        rep  = "Python class: %s\n" % self.pythonClass
84        rep += "  struc name: %s\n" % self.structName
85        rep += "  params:     %s\n" % self.params
86        return rep
87       
88    def read(self):
89        """ Reads in the .h file to catch parameters of the wrapper """
90       
91        # Check if the file is there
92        if not os.path.isfile(self.file):
93            raise ValueError, "File %s is not a regular file" % self.file
94       
95        # Read file
96        f = open(self.file,'r')
97        buf = f.read()
98       
99        self.default_list = "List of default parameters:\n"
100        #lines = string.split(buf,'\n')
101        lines = buf.split('\n')
102        self.details  = "## Parameter details [units, min, max]\n"
103        self.details += "        self.details = {}\n"
104        for line in lines:
105           
106            # Catch class name
107            key = "[PYTHONCLASS]"
108            if line.count(key)>0:
109                try:
110                    index = line.index(key)
111                    #toks = string.split( line[index:], "=" )
112                    toks = line[index:].split("=" )
113                    self.pythonClass = toks[1].lstrip().rstrip()
114                except:
115                    raise ValueError, "Could not parse file %s" % self.file
116               
117            # Catch struct name
118            if line.count("typedef struct")>0:
119                # We are entering a struct block
120                self.inStruct = True
121           
122            if self.inStruct and line.count("}")>0:
123                # We are exiting a struct block
124                self.inStruct = False
125   
126                # Catch the name of the struct
127                index = line.index("}")
128                #toks = string.split(line[index+1:],";")
129                toks = line[index+1:].split(";")
130                # Catch pointer definition
131                #toks2 = string.split(toks[0],',')
132                toks2 = toks[0].split(',')
133                self.structName = toks2[0].lstrip().rstrip()
134               
135            # Catch struct content
136            key = "[DEFAULT]"
137            if self.inStruct and line.count(key)>0:
138                # Found a new parameter
139                try:
140                    index = line.index(key)
141                    toks = line[index:].split("=")
142                    toks2 = toks[2].split()
143                    val = float(toks2[0])
144                    self.params[toks[1]] = val
145                    #self.pythonClass = toks[1].lstrip().rstrip()
146                    units = ""
147                    if len(toks2) >= 2:
148                        units = toks2[1]
149                    self.default_list += "         %-15s = %s %s\n" % \
150                        (toks[1], val, units)
151                   
152                    # Check for min and max
153                    min = "None"
154                    max = "None"
155                    if len(toks2) == 4:
156                        min = toks2[2]
157                        max = toks2[3]
158                   
159                    self.details += "        self.details['%s'] = ['%s', %s, %s]\n" % \
160                        (toks[1].lstrip().rstrip(), units.lstrip().rstrip(), min, max)
161                except:
162                    raise ValueError, "Could not parse input file %s \n  %s" % \
163                        (self.file, sys.exc_value)
164               
165               
166            # Catch need for numerical calculations
167            key = "CalcParameters calcPars"
168            if line.count(key)>0:
169                self.modelCalcFlag = True
170               
171               
172               
173    def write_c_wrapper(self):
174        """ Writes the C file to create the python extension class
175            The file is written in C[PYTHONCLASS].c
176        """
177       
178        file = open("C"+self.pythonClass+'.c', 'w')
179        template = open("classTemplate.txt", 'r')
180       
181        tmp_buf = template.read()
182        #tmp_lines = string.split(tmp_buf,'\n')
183        tmp_lines = tmp_buf.split('\n')
184       
185        for tmp_line in tmp_lines:
186           
187            # Catch class name
188            newline = self.replaceToken(tmp_line, 
189                                        "[PYTHONCLASS]", 'C'+self.pythonClass)
190           
191            # Catch class name
192            newline = self.replaceToken(newline, 
193                                        "[MODELSTRUCT]", self.structName)
194           
195            # Dictionary initialization
196            param_str = "// Initialize parameter dictionary\n"
197            for par in self.params:
198                param_str += "        PyDict_SetItemString(self->params,\"%s\",Py_BuildValue(\"d\",%f));\n" % \
199                    (par, self.params[par])
200               
201            newline = self.replaceToken(newline,
202                                        "[INITDICTIONARY]", param_str)
203           
204            # Read dictionary
205            param_str = "// Reader parameter dictionary\n"
206            for par in self.params:
207                param_str += "    self->model_pars.%s = PyFloat_AsDouble( PyDict_GetItemString(self->params, \"%s\") );\n" % \
208                    (par, par)
209               
210            newline = self.replaceToken(newline, "[READDICTIONARY]", param_str)
211               
212            # Name of .c file
213            #toks = string.split(self.file,'.')
214            toks = self.file.split('.')
215            newline = self.replaceToken(newline, "[C_FILENAME]", toks[0])
216           
217            # Include file
218            newline = self.replaceToken(newline, 
219                                        "[INCLUDE_FILE]", self.file)           
220               
221            # Numerical calcs dealloc
222            dealloc_str = "\n"
223            if self.modelCalcFlag:
224                dealloc_str = "    modelcalculations_dealloc(&(self->model_pars.calcPars));\n"
225            newline = self.replaceToken(newline, 
226                                        "[NUMERICAL_DEALLOC]", dealloc_str)     
227               
228            # Numerical calcs init
229            init_str = "\n"
230            if self.modelCalcFlag:
231                init_str = "        modelcalculations_init(&(self->model_pars.calcPars));\n"
232            newline = self.replaceToken(newline, 
233                                        "[NUMERICAL_INIT]", init_str)     
234               
235            # Numerical calcs reset
236            reset_str = "\n"
237            if self.modelCalcFlag:
238                reset_str = "modelcalculations_reset(&(self->model_pars.calcPars));\n"
239            newline = self.replaceToken(newline, 
240                                        "[NUMERICAL_RESET]", reset_str)     
241               
242            # Write new line to the wrapper .c file
243            file.write(newline+'\n')
244           
245           
246        file.close()
247       
248    def write_python_wrapper(self):
249        """ Writes the python file to create the python extension class
250            The file is written in ../[PYTHONCLASS].py
251        """
252       
253        file = open("../"+self.pythonClass+'.py', 'w')
254        template = open("modelTemplate.txt", 'r')
255       
256        tmp_buf = template.read()
257        tmp_lines = tmp_buf.split('\n')
258       
259        for tmp_line in tmp_lines:
260           
261            # Catch class name
262            newline = self.replaceToken(tmp_line, 
263                                        "[CPYTHONCLASS]", 'C'+self.pythonClass)
264           
265            # Catch class name
266            newline = self.replaceToken(newline, 
267                                        "[PYTHONCLASS]", self.pythonClass)
268           
269            # Include file
270            newline = self.replaceToken(newline, 
271                                        "[INCLUDE_FILE]", self.file)   
272                   
273            # Include file
274            newline = self.replaceToken(newline, 
275                                        "[DEFAULT_LIST]", self.default_list)
276
277            # Parameter details
278            newline = self.replaceToken(newline, 
279                                        "[PAR_DETAILS]", self.details)
280
281            # Write new line to the wrapper .c file
282            file.write(newline+'\n')
283               
284        file.close()
285       
286       
287    def replaceToken(self, line, key, value): #pylint: disable-msg=R0201
288        """ Replace a token in the template file
289            @param line: line of text to inspect
290            @param key: token to look for
291            @param value: string value to replace the token with
292            @return: new string value
293        """
294        lenkey = len(key)
295        newline = line
296        while newline.count(key)>0:
297            index = newline.index(key)
298            newline = newline[:index]+value+newline[index+lenkey:]
299        return newline
300       
301       
302# main
303if __name__ == '__main__':
304    if len(sys.argv)>1:
305        print "Will look for file %s" % sys.argv[1]
306        app = WrapperGenerator(sys.argv[1])
307    else:
308        app = WrapperGenerator("test.h")
309    app.read()
310    app.write_c_wrapper()
311    app.write_python_wrapper()
312    print app
313   
314# End of file       
Note: See TracBrowser for help on using the repository browser.