source: sasview/sansmodels/src/sans/models/c_models/WrapperGenerator.py @ 4d3acb6

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 4d3acb6 was 96672c0, checked in by Gervaise Alina <gervyh@…>, 16 years ago

code for description modified

  • Property mode set to 100644
File size: 15.8 KB
Line 
1#!/usr/bin/env python
2""" WrapperGenerator class to generate model code automatically.
3"""
4
5import os, sys
6
7class WrapperGenerator:
8    """ Python wrapper generator for C models
9   
10        The developer must provide a header file describing
11        the new model.
12       
13        To provide the name of the Python class to be
14        generated, the .h file must contain the following
15        string in the comments:
16       
17        // [PYTHONCLASS] = my_model
18       
19        where my_model must be replaced by the name of the
20        class that you want to import from sans.models.
21        (example: [PYTHONCLASS] = MyModel
22          will create a class MyModel in sans.models.MyModel.
23          It will also create a class CMyModel in
24          sans_extension.c_models.)
25         
26        Also in comments, each parameter of the params
27        dictionary must be declared with a default value
28        in the following way:
29       
30        // [DEFAULT]=param_name=default_value
31       
32        (example:
33            //  [DEFAULT]=radius=20.0
34        )
35         
36        See cylinder.h for an example.
37       
38       
39        A .c file corresponding to the .h file should also
40        be provided (example: my_model.h, my_model.c).
41   
42        The .h file should define two function definitions. For example,
43        cylinder.h defines the following:
44       
45            /// 1D scattering function
46            double cylinder_analytical_1D(CylinderParameters *pars, double q);
47           
48            /// 2D scattering function
49            double cylinder_analytical_2D(CylinderParameters *pars, double q, double phi);
50           
51        The .c file implements those functions.
52       
53        @author: Mathieu Doucet / UTK
54        @contact: mathieu.doucet@nist.gov
55    """
56   
57    def __init__(self, filename):
58        """ Initialization """
59       
60        ## Name of .h file to generate wrapper from
61        self.file = filename
62       
63        # Info read from file
64       
65        ## Name of python class to write
66        self.pythonClass = None
67        ## Parser in struct section
68        self.inStruct = False
69        ## Name of struct for the c object
70        self.structName = None
71        ## Dictionary of parameters
72        self.params = {}
73        ## ModelCalculation module flag
74        self.modelCalcFlag = False
75        ## List of default parameters (text)
76        self.default_list = ""
77        ## Dictionary of units
78        self.details = ""
79        ## List of dispersed parameters
80        self.disp_params = []
81        #model description
82        self.description=''
83       
84    def __repr__(self):
85        """ Simple output for printing """
86       
87        rep  = "Python class: %s\n" % self.pythonClass
88        rep += "  struc name: %s\n" % self.structName
89        rep += "  params:     %s\n" % self.params
90        rep += "  description:     %s\n" % self.description
91        return rep
92       
93    def read(self):
94        """ Reads in the .h file to catch parameters of the wrapper """
95       
96        # Check if the file is there
97        if not os.path.isfile(self.file):
98            raise ValueError, "File %s is not a regular file" % self.file
99       
100        # Read file
101        f = open(self.file,'r')
102        buf = f.read()
103       
104        self.default_list = "List of default parameters:\n"
105        #lines = string.split(buf,'\n')
106        lines = buf.split('\n')
107        self.details  = "## Parameter details [units, min, max]\n"
108        self.details += "        self.details = {}\n"
109         # Catch Description
110        key = "[DESCRIPTION]"
111        find_description= False
112        temp=""
113        for line in lines:
114            if line.count(key)>0 :
115                find_description= True
116                try:
117                    index = line.index(key)
118                    toks = line[index:].split("=",1 )
119                    temp=toks[1].lstrip().rstrip()
120                    text='text'
121                    key2="<%s>"%text.lower
122                    if re.match(key2,temp)!=None:
123                        index2 = line.index(key2)
124                        temp=line[index:]
125                    else:
126                        self.description=temp
127                except:
128                     raise ValueError, "Could not parse file %s" % self.file
129                if find_description:
130                    text='text'
131                    key2="</%s>"%text.lower
132                    if re.match(key2,temp)!=None:
133                        index2 = line.index(key2)
134                        temp=line[:index]
135                    temp+=line
136        self.description= temp
137               
138        for line in lines:
139           
140            # Catch class name
141            key = "[PYTHONCLASS]"
142            if line.count(key)>0:
143                try:
144                    index = line.index(key)
145                    #toks = string.split( line[index:], "=" )
146                    toks = line[index:].split("=" )
147                    self.pythonClass = toks[1].lstrip().rstrip()
148                except:
149                    raise ValueError, "Could not parse file %s" % self.file
150               
151            # Catch struct name
152            if line.count("typedef struct")>0:
153                # We are entering a struct block
154                self.inStruct = True
155           
156            if self.inStruct and line.count("}")>0:
157                # We are exiting a struct block
158                self.inStruct = False
159   
160                # Catch the name of the struct
161                index = line.index("}")
162                #toks = string.split(line[index+1:],";")
163                toks = line[index+1:].split(";")
164                # Catch pointer definition
165                #toks2 = string.split(toks[0],',')
166                toks2 = toks[0].split(',')
167                self.structName = toks2[0].lstrip().rstrip()
168           
169               
170           
171            # Catch struct content
172            key = "[DEFAULT]"
173            if self.inStruct and line.count(key)>0:
174                # Found a new parameter
175                try:
176                    index = line.index(key)
177                    toks = line[index:].split("=")
178                    toks2 = toks[2].split()
179                    val = float(toks2[0])
180                    self.params[toks[1]] = val
181                    #self.pythonClass = toks[1].lstrip().rstrip()
182                    units = ""
183                    if len(toks2) >= 2:
184                        units = toks2[1]
185                    self.default_list += "         %-15s = %s %s\n" % \
186                        (toks[1], val, units)
187                   
188                    # Check for min and max
189                    min = "None"
190                    max = "None"
191                    if len(toks2) == 4:
192                        min = toks2[2]
193                        max = toks2[3]
194                   
195                    self.details += "        self.details['%s'] = ['%s', %s, %s]\n" % \
196                        (toks[1].lstrip().rstrip(), units.lstrip().rstrip(), min, max)
197                except:
198                    raise ValueError, "Could not parse input file %s \n  %s" % \
199                        (self.file, sys.exc_value)
200               
201               
202            # Catch need for numerical calculations
203            key = "CalcParameters calcPars"
204            if line.count(key)>0:
205                self.modelCalcFlag = True
206               
207            # Catch list of dispersed parameters
208            key = "[DISP_PARAMS]"
209            if line.count(key)>0:
210                try:
211                    index = line.index(key)
212                    toks = line[index:].split("=")
213                    list_str = toks[1].lstrip().rstrip()
214                    self.disp_params = list_str.split(',')
215                except:
216                    raise ValueError, "Could not parse file %s" % self.file
217               
218               
219               
220    def write_c_wrapper(self):
221        """ Writes the C file to create the python extension class
222            The file is written in C[PYTHONCLASS].c
223        """
224       
225        file = open("C"+self.pythonClass+'.cpp', 'w')
226        template = open("classTemplate.txt", 'r')
227       
228        tmp_buf = template.read()
229        #tmp_lines = string.split(tmp_buf,'\n')
230        tmp_lines = tmp_buf.split('\n')
231       
232        for tmp_line in tmp_lines:
233           
234            # Catch class name
235            newline = self.replaceToken(tmp_line, 
236                                        "[PYTHONCLASS]", 'C'+self.pythonClass)
237           
238            # Catch C model name
239            newline = self.replaceToken(newline, 
240                                        "[CMODEL]", self.pythonClass)
241           
242            # Catch class name
243            newline = self.replaceToken(newline, 
244                                        "[MODELSTRUCT]", self.structName)
245           
246            # Dictionary initialization
247            param_str = "// Initialize parameter dictionary\n"           
248            for par in self.params:
249                param_str += "        PyDict_SetItemString(self->params,\"%s\",Py_BuildValue(\"d\",%f));\n" % \
250                    (par, self.params[par])
251
252            param_str += "        // Initialize dispersion / averaging parameter dict\n"
253            param_str += "        DispersionVisitor* visitor = new DispersionVisitor();\n"
254            param_str += "        PyObject * disp_dict;\n"
255            for par in self.disp_params:
256                par = par.strip()
257                param_str += "        disp_dict = PyDict_New();\n"
258                param_str += "        self->model->%s.dispersion->accept_as_source(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
259                param_str += "        PyDict_SetItemString(self->dispersion, \"%s\", disp_dict);\n" % par
260               
261            # Initialize dispersion object dictionnary
262            param_str += "\n"
263           
264               
265            newline = self.replaceToken(newline,
266                                        "[INITDICTIONARY]", param_str)
267           
268            # Read dictionary
269            param_str = "    // Reader parameter dictionary\n"
270            for par in self.params:
271                param_str += "    self->model->%s = PyFloat_AsDouble( PyDict_GetItemString(self->params, \"%s\") );\n" % \
272                    (par, par)
273                   
274            param_str += "    // Read in dispersion parameters\n"
275            param_str += "    PyObject* disp_dict;\n"
276            param_str += "    DispersionVisitor* visitor = new DispersionVisitor();\n"
277            for par in self.disp_params:
278                par = par.strip()
279                param_str += "    disp_dict = PyDict_GetItemString(self->dispersion, \"%s\");\n" % par
280                param_str += "    self->model->%s.dispersion->accept_as_destination(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
281               
282            newline = self.replaceToken(newline, "[READDICTIONARY]", param_str)
283               
284            # Name of .c file
285            #toks = string.split(self.file,'.')
286            basename = os.path.basename(self.file)
287            toks = basename.split('.')
288            newline = self.replaceToken(newline, "[C_FILENAME]", toks[0])
289           
290            # Include file
291            basename = os.path.basename(self.file)
292            newline = self.replaceToken(newline, 
293                                        "[INCLUDE_FILE]", basename)           
294               
295            # Numerical calcs dealloc
296            dealloc_str = "\n"
297            if self.modelCalcFlag:
298                dealloc_str = "    modelcalculations_dealloc(&(self->model_pars.calcPars));\n"
299            newline = self.replaceToken(newline, 
300                                        "[NUMERICAL_DEALLOC]", dealloc_str)     
301               
302            # Numerical calcs init
303            init_str = "\n"
304            if self.modelCalcFlag:
305                init_str = "        modelcalculations_init(&(self->model_pars.calcPars));\n"
306            newline = self.replaceToken(newline, 
307                                        "[NUMERICAL_INIT]", init_str)     
308               
309            # Numerical calcs reset
310            reset_str = "\n"
311            if self.modelCalcFlag:
312                reset_str = "modelcalculations_reset(&(self->model_pars.calcPars));\n"
313            newline = self.replaceToken(newline, 
314                                        "[NUMERICAL_RESET]", reset_str)     
315               
316            # Setting dispsertion weights
317            set_weights = "    // Ugliness necessary to go from python to C\n"
318            set_weights = "    // TODO: refactor this\n"
319            for par in self.disp_params:
320                par = par.strip()
321                set_weights += "    if (!strcmp(par_name, \"%s\")) {\n" % par
322                set_weights += "        self->model->%s.dispersion = dispersion;\n" % par
323                set_weights += "    } else"
324            newline = self.replaceToken(newline, 
325                                        "[SET_DISPERSION]", set_weights)     
326           
327           
328            # Write new line to the wrapper .c file
329            file.write(newline+'\n')
330           
331           
332        file.close()
333       
334    def write_python_wrapper(self):
335        """ Writes the python file to create the python extension class
336            The file is written in ../[PYTHONCLASS].py
337        """
338       
339        file = open("../"+self.pythonClass+'.py', 'w')
340        template = open("modelTemplate.txt", 'r')
341       
342        tmp_buf = template.read()
343        tmp_lines = tmp_buf.split('\n')
344       
345        for tmp_line in tmp_lines:
346           
347            # Catch class name
348            newline = self.replaceToken(tmp_line, 
349                                        "[CPYTHONCLASS]", 'C'+self.pythonClass)
350           
351            # Catch class name
352            newline = self.replaceToken(newline, 
353                                        "[PYTHONCLASS]", self.pythonClass)
354           
355            # Include file
356            newline = self.replaceToken(newline, 
357                                        "[INCLUDE_FILE]", self.file)   
358                   
359            # Include file
360            newline = self.replaceToken(newline, 
361                                        "[DEFAULT_LIST]", self.default_list)
362
363            # Parameter details
364            newline = self.replaceToken(newline, 
365                                        "[PAR_DETAILS]", self.details)
366            # Parameter details
367            newline = self.replaceToken(newline, 
368                                        "[DESCRIPTION]", self.description)
369
370            # Write new line to the wrapper .c file
371            file.write(newline+'\n')
372               
373        file.close()
374       
375       
376    def replaceToken(self, line, key, value): #pylint: disable-msg=R0201
377        """ Replace a token in the template file
378            @param line: line of text to inspect
379            @param key: token to look for
380            @param value: string value to replace the token with
381            @return: new string value
382        """
383        lenkey = len(key)
384        newline = line
385        while newline.count(key)>0:
386            index = newline.index(key)
387            newline = newline[:index]+value+newline[index+lenkey:]
388        return newline
389       
390       
391# main
392if __name__ == '__main__':
393    if len(sys.argv)>1:
394        print "Will look for file %s" % sys.argv[1]
395        app = WrapperGenerator(sys.argv[1])
396    else:
397        app = WrapperGenerator("test.h")
398    app.read()
399    app.write_c_wrapper()
400    app.write_python_wrapper()
401    print app
402   
403# End of file       
Note: See TracBrowser for help on using the repository browser.