source: sasview/sansmodels/src/sans/models/c_models/WrapperGenerator.py @ 7608cb5

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 7608cb5 was 95986b5, checked in by Gervaise Alina <gervyh@…>, 16 years ago

sans modelsgenerated with wrappergenerator c_models

  • Property mode set to 100644
File size: 17.1 KB
Line 
1#!/usr/bin/env python
2""" WrapperGenerator class to generate model code automatically.
3"""
4
5import os, sys,re
6
7class WrapperGenerator:
8    """ Python wrapper generator for C models
9   
10        The developer must provide a header file describing
11        the new model.
12       
13        To provide the name of the Python class to be
14        generated, the .h file must contain the following
15        string in the comments:
16       
17        // [PYTHONCLASS] = my_model
18       
19        where my_model must be replaced by the name of the
20        class that you want to import from sans.models.
21        (example: [PYTHONCLASS] = MyModel
22          will create a class MyModel in sans.models.MyModel.
23          It will also create a class CMyModel in
24          sans_extension.c_models.)
25         
26        Also in comments, each parameter of the params
27        dictionary must be declared with a default value
28        in the following way:
29       
30        // [DEFAULT]=param_name=default_value
31       
32        (example:
33            //  [DEFAULT]=radius=20.0
34        )
35         
36        See cylinder.h for an example.
37       
38       
39        A .c file corresponding to the .h file should also
40        be provided (example: my_model.h, my_model.c).
41   
42        The .h file should define two function definitions. For example,
43        cylinder.h defines the following:
44       
45            /// 1D scattering function
46            double cylinder_analytical_1D(CylinderParameters *pars, double q);
47           
48            /// 2D scattering function
49            double cylinder_analytical_2D(CylinderParameters *pars, double q, double phi);
50           
51        The .c file implements those functions.
52       
53        @author: Mathieu Doucet / UTK
54        @contact: mathieu.doucet@nist.gov
55    """
56   
57    def __init__(self, filename):
58        """ Initialization """
59       
60        ## Name of .h file to generate wrapper from
61        self.file = filename
62       
63        # Info read from file
64       
65        ## Name of python class to write
66        self.pythonClass = None
67        ## Parser in struct section
68        self.inStruct = False
69        ## Name of struct for the c object
70        self.structName = None
71        ## Dictionary of parameters
72        self.params = {}
73        ## ModelCalculation module flag
74        self.modelCalcFlag = False
75        ## List of default parameters (text)
76        self.default_list = ""
77        ## Dictionary of units
78        self.details = ""
79        ## List of dispersed parameters
80        self.disp_params = []
81        #model description
82        self.description=''
83       
84    def __repr__(self):
85        """ Simple output for printing """
86       
87        rep  = "Python class: %s\n" % self.pythonClass
88        rep += "  struc name: %s\n" % self.structName
89        rep += "  params:     %s\n" % self.params
90        rep += "  description:     %s\n" % self.description
91        return rep
92       
93    def read(self):
94        """ Reads in the .h file to catch parameters of the wrapper """
95       
96        # Check if the file is there
97        if not os.path.isfile(self.file):
98            raise ValueError, "File %s is not a regular file" % self.file
99       
100        # Read file
101        f = open(self.file,'r')
102        buf = f.read()
103       
104        self.default_list = "List of default parameters:\n"
105        #lines = string.split(buf,'\n')
106        lines = buf.split('\n')
107        self.details  = "## Parameter details [units, min, max]\n"
108        self.details += "        self.details = {}\n"
109         # Catch Description
110        key = "[DESCRIPTION]"
111        find_description= 0
112        temp=""
113        for line in lines:
114            if line.count(key)>0 :
115               
116                try:
117                    find_description= 1
118                    index = line.index(key)
119                    toks = line[index:].split("=",1 )
120                    temp=toks[1].lstrip().rstrip()
121                    text='text'
122                    key2="<%s>"%text.lower()
123                    if re.match(key2,temp)!=None:
124                        #index2 = line.index(key2)
125                        #temp = temp[index2:]
126                        toks2=temp.split(key2,1)
127                        self.description=toks2[1]
128                        text='text'
129                        key2="</%s>"%text.lower()
130                        if re.search(key2,toks2[1])!=None:
131                            temp=toks2[1].split(key2,1)
132                            self.description=temp[0]
133                            break
134                        #print self.description
135                    else:
136                        self.description=temp
137                        break
138                except:
139                     raise
140                     raise ValueError, "Could not parse file %s" % self.file
141            elif find_description==1:
142                text='text'
143                key2="</%s>"%text.lower()
144                #print "second line",line,key2,re.search(key2,line)
145                if re.search(key2,line)!=None:
146                    tok=line.split(key2,1)
147                    temp=tok[0].split("//",1)
148                    self.description+=tok[1].lstrip().rstrip()
149                    break
150                else:
151                    #if re.search("*",line)!=None:
152                    #    temp=line.split("*",1)
153                    #    self.description+='\n'+temp[1].lstrip().rstrip()
154                    if re.search("//",line)!=None:
155                        temp=line.split("//",1)
156                        self.description+='\n\t\t'+temp[1].lstrip().rstrip()
157                       
158                    else:
159                        self.description+='\n\t\t'+line.lstrip().rstrip()
160                   
161               
162               
163        for line in lines:
164           
165            # Catch class name
166            key = "[PYTHONCLASS]"
167            if line.count(key)>0:
168                try:
169                    index = line.index(key)
170                    #toks = string.split( line[index:], "=" )
171                    toks = line[index:].split("=" )
172                    self.pythonClass = toks[1].lstrip().rstrip()
173                except:
174                    raise ValueError, "Could not parse file %s" % self.file
175               
176            # Catch struct name
177            if line.count("typedef struct")>0:
178                # We are entering a struct block
179                self.inStruct = True
180           
181            if self.inStruct and line.count("}")>0:
182                # We are exiting a struct block
183                self.inStruct = False
184   
185                # Catch the name of the struct
186                index = line.index("}")
187                #toks = string.split(line[index+1:],";")
188                toks = line[index+1:].split(";")
189                # Catch pointer definition
190                #toks2 = string.split(toks[0],',')
191                toks2 = toks[0].split(',')
192                self.structName = toks2[0].lstrip().rstrip()
193           
194               
195           
196            # Catch struct content
197            key = "[DEFAULT]"
198            if self.inStruct and line.count(key)>0:
199                # Found a new parameter
200                try:
201                    index = line.index(key)
202                    toks = line[index:].split("=")
203                    toks2 = toks[2].split()
204                    val = float(toks2[0])
205                    self.params[toks[1]] = val
206                    #self.pythonClass = toks[1].lstrip().rstrip()
207                    units = ""
208                    if len(toks2) >= 2:
209                        units = toks2[1]
210                    self.default_list += "         %-15s = %s %s\n" % \
211                        (toks[1], val, units)
212                   
213                    # Check for min and max
214                    min = "None"
215                    max = "None"
216                    if len(toks2) == 4:
217                        min = toks2[2]
218                        max = toks2[3]
219                   
220                    self.details += "        self.details['%s'] = ['%s', %s, %s]\n" % \
221                        (toks[1].lstrip().rstrip(), units.lstrip().rstrip(), min, max)
222                except:
223                    raise ValueError, "Could not parse input file %s \n  %s" % \
224                        (self.file, sys.exc_value)
225               
226               
227            # Catch need for numerical calculations
228            key = "CalcParameters calcPars"
229            if line.count(key)>0:
230                self.modelCalcFlag = True
231               
232            # Catch list of dispersed parameters
233            key = "[DISP_PARAMS]"
234            if line.count(key)>0:
235                try:
236                    index = line.index(key)
237                    toks = line[index:].split("=")
238                    list_str = toks[1].lstrip().rstrip()
239                    self.disp_params = list_str.split(',')
240                except:
241                    raise ValueError, "Could not parse file %s" % self.file
242               
243       
244               
245    def write_c_wrapper(self):
246        """ Writes the C file to create the python extension class
247            The file is written in C[PYTHONCLASS].c
248        """
249       
250        file = open("C"+self.pythonClass+'.cpp', 'w')
251        template = open("classTemplate.txt", 'r')
252       
253        tmp_buf = template.read()
254        #tmp_lines = string.split(tmp_buf,'\n')
255        tmp_lines = tmp_buf.split('\n')
256       
257        for tmp_line in tmp_lines:
258           
259            # Catch class name
260            newline = self.replaceToken(tmp_line, 
261                                        "[PYTHONCLASS]", 'C'+self.pythonClass)
262            #Catch model description
263            #newline = self.replaceToken(tmp_line,
264            #                            "[DESCRIPTION]", self.description)
265            # Catch C model name
266            newline = self.replaceToken(newline, 
267                                        "[CMODEL]", self.pythonClass)
268           
269            # Catch class name
270            newline = self.replaceToken(newline, 
271                                        "[MODELSTRUCT]", self.structName)
272           
273            # Dictionary initialization
274            param_str = "// Initialize parameter dictionary\n"           
275            for par in self.params:
276                param_str += "        PyDict_SetItemString(self->params,\"%s\",Py_BuildValue(\"d\",%f));\n" % \
277                    (par, self.params[par])
278
279            param_str += "        // Initialize dispersion / averaging parameter dict\n"
280            param_str += "        DispersionVisitor* visitor = new DispersionVisitor();\n"
281            param_str += "        PyObject * disp_dict;\n"
282            for par in self.disp_params:
283                par = par.strip()
284                param_str += "        disp_dict = PyDict_New();\n"
285                param_str += "        self->model->%s.dispersion->accept_as_source(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
286                param_str += "        PyDict_SetItemString(self->dispersion, \"%s\", disp_dict);\n" % par
287               
288            # Initialize dispersion object dictionnary
289            param_str += "\n"
290           
291               
292            newline = self.replaceToken(newline,
293                                        "[INITDICTIONARY]", param_str)
294           
295            # Read dictionary
296            param_str = "    // Reader parameter dictionary\n"
297            for par in self.params:
298                param_str += "    self->model->%s = PyFloat_AsDouble( PyDict_GetItemString(self->params, \"%s\") );\n" % \
299                    (par, par)
300                   
301            param_str += "    // Read in dispersion parameters\n"
302            param_str += "    PyObject* disp_dict;\n"
303            param_str += "    DispersionVisitor* visitor = new DispersionVisitor();\n"
304            for par in self.disp_params:
305                par = par.strip()
306                param_str += "    disp_dict = PyDict_GetItemString(self->dispersion, \"%s\");\n" % par
307                param_str += "    self->model->%s.dispersion->accept_as_destination(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
308               
309            newline = self.replaceToken(newline, "[READDICTIONARY]", param_str)
310               
311            # Name of .c file
312            #toks = string.split(self.file,'.')
313            basename = os.path.basename(self.file)
314            toks = basename.split('.')
315            newline = self.replaceToken(newline, "[C_FILENAME]", toks[0])
316           
317            # Include file
318            basename = os.path.basename(self.file)
319            newline = self.replaceToken(newline, 
320                                        "[INCLUDE_FILE]", basename)           
321               
322            # Numerical calcs dealloc
323            dealloc_str = "\n"
324            if self.modelCalcFlag:
325                dealloc_str = "    modelcalculations_dealloc(&(self->model_pars.calcPars));\n"
326            newline = self.replaceToken(newline, 
327                                        "[NUMERICAL_DEALLOC]", dealloc_str)     
328               
329            # Numerical calcs init
330            init_str = "\n"
331            if self.modelCalcFlag:
332                init_str = "        modelcalculations_init(&(self->model_pars.calcPars));\n"
333            newline = self.replaceToken(newline, 
334                                        "[NUMERICAL_INIT]", init_str)     
335               
336            # Numerical calcs reset
337            reset_str = "\n"
338            if self.modelCalcFlag:
339                reset_str = "modelcalculations_reset(&(self->model_pars.calcPars));\n"
340            newline = self.replaceToken(newline, 
341                                        "[NUMERICAL_RESET]", reset_str)     
342               
343            # Setting dispsertion weights
344            set_weights = "    // Ugliness necessary to go from python to C\n"
345            set_weights = "    // TODO: refactor this\n"
346            for par in self.disp_params:
347                par = par.strip()
348                set_weights += "    if (!strcmp(par_name, \"%s\")) {\n" % par
349                set_weights += "        self->model->%s.dispersion = dispersion;\n" % par
350                set_weights += "    } else"
351            newline = self.replaceToken(newline, 
352                                        "[SET_DISPERSION]", set_weights)     
353           
354            # Write new line to the wrapper .c file
355            file.write(newline+'\n')
356           
357           
358        file.close()
359       
360    def write_python_wrapper(self):
361        """ Writes the python file to create the python extension class
362            The file is written in ../[PYTHONCLASS].py
363        """
364       
365        file = open("../"+self.pythonClass+'.py', 'w')
366        template = open("modelTemplate.txt", 'r')
367       
368        tmp_buf = template.read()
369        tmp_lines = tmp_buf.split('\n')
370       
371        for tmp_line in tmp_lines:
372           
373            # Catch class name
374            newline = self.replaceToken(tmp_line, 
375                                        "[CPYTHONCLASS]", 'C'+self.pythonClass)
376           
377            # Catch class name
378            newline = self.replaceToken(newline, 
379                                        "[PYTHONCLASS]", self.pythonClass)
380           
381            # Include file
382            newline = self.replaceToken(newline, 
383                                        "[INCLUDE_FILE]", self.file)   
384                   
385            # Include file
386            newline = self.replaceToken(newline, 
387                                        "[DEFAULT_LIST]", self.default_list)
388            # model description
389            newline = self.replaceToken(newline, 
390                                        "[DESCRIPTION]", self.description)
391            # Parameter details
392            newline = self.replaceToken(newline, 
393                                        "[PAR_DETAILS]", self.details)
394           
395
396            # Write new line to the wrapper .c file
397            file.write(newline+'\n')
398               
399        file.close()
400       
401       
402    def replaceToken(self, line, key, value): #pylint: disable-msg=R0201
403        """ Replace a token in the template file
404            @param line: line of text to inspect
405            @param key: token to look for
406            @param value: string value to replace the token with
407            @return: new string value
408        """
409        lenkey = len(key)
410        newline = line
411        while newline.count(key)>0:
412            index = newline.index(key)
413            newline = newline[:index]+value+newline[index+lenkey:]
414        return newline
415       
416       
417# main
418if __name__ == '__main__':
419    if len(sys.argv)>1:
420        print "Will look for file %s" % sys.argv[1]
421    #app = WrapperGenerator('../c_extensions/elliptical_cylinder.h')
422        app = WrapperGenerator(sys.argv[1])
423    else:
424        app = WrapperGenerator("test.h")
425    app.read()
426    app.write_c_wrapper()
427    app.write_python_wrapper()
428    print app
429   
430# End of file       
Note: See TracBrowser for help on using the repository browser.