source: sasview/sansmodels/src/sans/models/c_models/WrapperGenerator.py @ 4e2f6ef8

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 4e2f6ef8 was 4e2f6ef8, checked in by Gervaise Alina <gervyh@…>, 16 years ago

add description

  • Property mode set to 100644
File size: 14.6 KB
Line 
1#!/usr/bin/env python
2""" WrapperGenerator class to generate model code automatically.
3"""
4
5import os, sys
6
7class WrapperGenerator:
8    """ Python wrapper generator for C models
9   
10        The developer must provide a header file describing
11        the new model.
12       
13        To provide the name of the Python class to be
14        generated, the .h file must contain the following
15        string in the comments:
16       
17        // [PYTHONCLASS] = my_model
18       
19        where my_model must be replaced by the name of the
20        class that you want to import from sans.models.
21        (example: [PYTHONCLASS] = MyModel
22          will create a class MyModel in sans.models.MyModel.
23          It will also create a class CMyModel in
24          sans_extension.c_models.)
25         
26        Also in comments, each parameter of the params
27        dictionary must be declared with a default value
28        in the following way:
29       
30        // [DEFAULT]=param_name=default_value
31       
32        (example:
33            //  [DEFAULT]=radius=20.0
34        )
35         
36        See cylinder.h for an example.
37       
38       
39        A .c file corresponding to the .h file should also
40        be provided (example: my_model.h, my_model.c).
41   
42        The .h file should define two function definitions. For example,
43        cylinder.h defines the following:
44       
45            /// 1D scattering function
46            double cylinder_analytical_1D(CylinderParameters *pars, double q);
47           
48            /// 2D scattering function
49            double cylinder_analytical_2D(CylinderParameters *pars, double q, double phi);
50           
51        The .c file implements those functions.
52       
53        @author: Mathieu Doucet / UTK
54        @contact: mathieu.doucet@nist.gov
55    """
56   
57    def __init__(self, filename):
58        """ Initialization """
59       
60        ## Name of .h file to generate wrapper from
61        self.file = filename
62       
63        # Info read from file
64       
65        ## Name of python class to write
66        self.pythonClass = None
67        ## Parser in struct section
68        self.inStruct = False
69        ## Name of struct for the c object
70        self.structName = None
71        ## Dictionary of parameters
72        self.params = {}
73        ## ModelCalculation module flag
74        self.modelCalcFlag = False
75        ## List of default parameters (text)
76        self.default_list = ""
77        ## Dictionary of units
78        self.details = ""
79        ## List of dispersed parameters
80        self.disp_params = []
81        #model description
82        self.description=''
83       
84    def __repr__(self):
85        """ Simple output for printing """
86       
87        rep  = "Python class: %s\n" % self.pythonClass
88        rep += "  struc name: %s\n" % self.structName
89        rep += "  params:     %s\n" % self.params
90        return rep
91       
92    def read(self):
93        """ Reads in the .h file to catch parameters of the wrapper """
94       
95        # Check if the file is there
96        if not os.path.isfile(self.file):
97            raise ValueError, "File %s is not a regular file" % self.file
98       
99        # Read file
100        f = open(self.file,'r')
101        buf = f.read()
102       
103        self.default_list = "List of default parameters:\n"
104        #lines = string.split(buf,'\n')
105        lines = buf.split('\n')
106        self.details  = "## Parameter details [units, min, max]\n"
107        self.details += "        self.details = {}\n"
108        for line in lines:
109           
110            # Catch class name
111            key = "[PYTHONCLASS]"
112            if line.count(key)>0:
113                try:
114                    index = line.index(key)
115                    #toks = string.split( line[index:], "=" )
116                    toks = line[index:].split("=" )
117                    self.pythonClass = toks[1].lstrip().rstrip()
118                except:
119                    raise ValueError, "Could not parse file %s" % self.file
120               
121            # Catch struct name
122            if line.count("typedef struct")>0:
123                # We are entering a struct block
124                self.inStruct = True
125           
126            if self.inStruct and line.count("}")>0:
127                # We are exiting a struct block
128                self.inStruct = False
129   
130                # Catch the name of the struct
131                index = line.index("}")
132                #toks = string.split(line[index+1:],";")
133                toks = line[index+1:].split(";")
134                # Catch pointer definition
135                #toks2 = string.split(toks[0],',')
136                toks2 = toks[0].split(',')
137                self.structName = toks2[0].lstrip().rstrip()
138               
139            # Catch struct content
140            key = "[DEFAULT]"
141            if self.inStruct and line.count(key)>0:
142                # Found a new parameter
143                try:
144                    index = line.index(key)
145                    toks = line[index:].split("=")
146                    toks2 = toks[2].split()
147                    val = float(toks2[0])
148                    self.params[toks[1]] = val
149                    #self.pythonClass = toks[1].lstrip().rstrip()
150                    units = ""
151                    if len(toks2) >= 2:
152                        units = toks2[1]
153                    self.default_list += "         %-15s = %s %s\n" % \
154                        (toks[1], val, units)
155                   
156                    # Check for min and max
157                    min = "None"
158                    max = "None"
159                    if len(toks2) == 4:
160                        min = toks2[2]
161                        max = toks2[3]
162                   
163                    self.details += "        self.details['%s'] = ['%s', %s, %s]\n" % \
164                        (toks[1].lstrip().rstrip(), units.lstrip().rstrip(), min, max)
165                except:
166                    raise ValueError, "Could not parse input file %s \n  %s" % \
167                        (self.file, sys.exc_value)
168               
169               
170            # Catch need for numerical calculations
171            key = "CalcParameters calcPars"
172            if line.count(key)>0:
173                self.modelCalcFlag = True
174               
175            # Catch list of dispersed parameters
176            key = "[DISP_PARAMS]"
177            if line.count(key)>0:
178                try:
179                    index = line.index(key)
180                    toks = line[index:].split("=")
181                    list_str = toks[1].lstrip().rstrip()
182                    self.disp_params = list_str.split(',')
183                except:
184                    raise ValueError, "Could not parse file %s" % self.file
185               
186               
187               
188    def write_c_wrapper(self):
189        """ Writes the C file to create the python extension class
190            The file is written in C[PYTHONCLASS].c
191        """
192       
193        file = open("C"+self.pythonClass+'.cpp', 'w')
194        template = open("classTemplate.txt", 'r')
195       
196        tmp_buf = template.read()
197        #tmp_lines = string.split(tmp_buf,'\n')
198        tmp_lines = tmp_buf.split('\n')
199       
200        for tmp_line in tmp_lines:
201           
202            # Catch class name
203            newline = self.replaceToken(tmp_line, 
204                                        "[PYTHONCLASS]", 'C'+self.pythonClass)
205           
206            # Catch C model name
207            newline = self.replaceToken(newline, 
208                                        "[CMODEL]", self.pythonClass)
209           
210            # Catch class name
211            newline = self.replaceToken(newline, 
212                                        "[MODELSTRUCT]", self.structName)
213           
214            # Dictionary initialization
215            param_str = "// Initialize parameter dictionary\n"           
216            for par in self.params:
217                param_str += "        PyDict_SetItemString(self->params,\"%s\",Py_BuildValue(\"d\",%f));\n" % \
218                    (par, self.params[par])
219
220            param_str += "        // Initialize dispersion / averaging parameter dict\n"
221            param_str += "        DispersionVisitor* visitor = new DispersionVisitor();\n"
222            param_str += "        PyObject * disp_dict;\n"
223            for par in self.disp_params:
224                par = par.strip()
225                param_str += "        disp_dict = PyDict_New();\n"
226                param_str += "        self->model->%s.dispersion->accept_as_source(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
227                param_str += "        PyDict_SetItemString(self->dispersion, \"%s\", disp_dict);\n" % par
228               
229            # Initialize dispersion object dictionnary
230            param_str += "\n"
231           
232               
233            newline = self.replaceToken(newline,
234                                        "[INITDICTIONARY]", param_str)
235           
236            # Read dictionary
237            param_str = "    // Reader parameter dictionary\n"
238            for par in self.params:
239                param_str += "    self->model->%s = PyFloat_AsDouble( PyDict_GetItemString(self->params, \"%s\") );\n" % \
240                    (par, par)
241                   
242            param_str += "    // Read in dispersion parameters\n"
243            param_str += "    PyObject* disp_dict;\n"
244            param_str += "    DispersionVisitor* visitor = new DispersionVisitor();\n"
245            for par in self.disp_params:
246                par = par.strip()
247                param_str += "    disp_dict = PyDict_GetItemString(self->dispersion, \"%s\");\n" % par
248                param_str += "    self->model->%s.dispersion->accept_as_destination(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
249               
250            newline = self.replaceToken(newline, "[READDICTIONARY]", param_str)
251               
252            # Name of .c file
253            #toks = string.split(self.file,'.')
254            basename = os.path.basename(self.file)
255            toks = basename.split('.')
256            newline = self.replaceToken(newline, "[C_FILENAME]", toks[0])
257           
258            # Include file
259            basename = os.path.basename(self.file)
260            newline = self.replaceToken(newline, 
261                                        "[INCLUDE_FILE]", basename)           
262               
263            # Numerical calcs dealloc
264            dealloc_str = "\n"
265            if self.modelCalcFlag:
266                dealloc_str = "    modelcalculations_dealloc(&(self->model_pars.calcPars));\n"
267            newline = self.replaceToken(newline, 
268                                        "[NUMERICAL_DEALLOC]", dealloc_str)     
269               
270            # Numerical calcs init
271            init_str = "\n"
272            if self.modelCalcFlag:
273                init_str = "        modelcalculations_init(&(self->model_pars.calcPars));\n"
274            newline = self.replaceToken(newline, 
275                                        "[NUMERICAL_INIT]", init_str)     
276               
277            # Numerical calcs reset
278            reset_str = "\n"
279            if self.modelCalcFlag:
280                reset_str = "modelcalculations_reset(&(self->model_pars.calcPars));\n"
281            newline = self.replaceToken(newline, 
282                                        "[NUMERICAL_RESET]", reset_str)     
283               
284            # Setting dispsertion weights
285            set_weights = "    // Ugliness necessary to go from python to C\n"
286            set_weights = "    // TODO: refactor this\n"
287            for par in self.disp_params:
288                par = par.strip()
289                set_weights += "    if (!strcmp(par_name, \"%s\")) {\n" % par
290                set_weights += "        self->model->%s.dispersion = dispersion;\n" % par
291                set_weights += "    } else"
292            newline = self.replaceToken(newline, 
293                                        "[SET_DISPERSION]", set_weights)     
294           
295           
296            # Write new line to the wrapper .c file
297            file.write(newline+'\n')
298           
299           
300        file.close()
301       
302    def write_python_wrapper(self):
303        """ Writes the python file to create the python extension class
304            The file is written in ../[PYTHONCLASS].py
305        """
306       
307        file = open("../"+self.pythonClass+'.py', 'w')
308        template = open("modelTemplate.txt", 'r')
309       
310        tmp_buf = template.read()
311        tmp_lines = tmp_buf.split('\n')
312       
313        for tmp_line in tmp_lines:
314           
315            # Catch class name
316            newline = self.replaceToken(tmp_line, 
317                                        "[CPYTHONCLASS]", 'C'+self.pythonClass)
318           
319            # Catch class name
320            newline = self.replaceToken(newline, 
321                                        "[PYTHONCLASS]", self.pythonClass)
322           
323            # Include file
324            newline = self.replaceToken(newline, 
325                                        "[INCLUDE_FILE]", self.file)   
326                   
327            # Include file
328            newline = self.replaceToken(newline, 
329                                        "[DEFAULT_LIST]", self.default_list)
330
331            # Parameter details
332            newline = self.replaceToken(newline, 
333                                        "[PAR_DETAILS]", self.details)
334            # Parameter details
335            newline = self.replaceToken(newline, 
336                                        "[DESCRIPTION]", self.description)
337
338            # Write new line to the wrapper .c file
339            file.write(newline+'\n')
340               
341        file.close()
342       
343       
344    def replaceToken(self, line, key, value): #pylint: disable-msg=R0201
345        """ Replace a token in the template file
346            @param line: line of text to inspect
347            @param key: token to look for
348            @param value: string value to replace the token with
349            @return: new string value
350        """
351        lenkey = len(key)
352        newline = line
353        while newline.count(key)>0:
354            index = newline.index(key)
355            newline = newline[:index]+value+newline[index+lenkey:]
356        return newline
357       
358       
359# main
360if __name__ == '__main__':
361    if len(sys.argv)>1:
362        print "Will look for file %s" % sys.argv[1]
363        app = WrapperGenerator(sys.argv[1])
364    else:
365        app = WrapperGenerator("test.h")
366    app.read()
367    app.write_c_wrapper()
368    app.write_python_wrapper()
369    print app
370   
371# End of file       
Note: See TracBrowser for help on using the repository browser.