source: sasview/sansmodels/src/sans/models/c_models/WrapperGenerator.py @ 9316609

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 9316609 was 9316609, checked in by Gervaise Alina <gervyh@…>, 16 years ago

added description to model

  • Property mode set to 100644
File size: 17.0 KB
Line 
1#!/usr/bin/env python
2""" WrapperGenerator class to generate model code automatically.
3"""
4
5import os, sys
6
7class WrapperGenerator:
8    """ Python wrapper generator for C models
9   
10        The developer must provide a header file describing
11        the new model.
12       
13        To provide the name of the Python class to be
14        generated, the .h file must contain the following
15        string in the comments:
16       
17        // [PYTHONCLASS] = my_model
18       
19        where my_model must be replaced by the name of the
20        class that you want to import from sans.models.
21        (example: [PYTHONCLASS] = MyModel
22          will create a class MyModel in sans.models.MyModel.
23          It will also create a class CMyModel in
24          sans_extension.c_models.)
25         
26        Also in comments, each parameter of the params
27        dictionary must be declared with a default value
28        in the following way:
29       
30        // [DEFAULT]=param_name=default_value
31       
32        (example:
33            //  [DEFAULT]=radius=20.0
34        )
35         
36        See cylinder.h for an example.
37       
38       
39        A .c file corresponding to the .h file should also
40        be provided (example: my_model.h, my_model.c).
41   
42        The .h file should define two function definitions. For example,
43        cylinder.h defines the following:
44       
45            /// 1D scattering function
46            double cylinder_analytical_1D(CylinderParameters *pars, double q);
47           
48            /// 2D scattering function
49            double cylinder_analytical_2D(CylinderParameters *pars, double q, double phi);
50           
51        The .c file implements those functions.
52       
53        @author: Mathieu Doucet / UTK
54        @contact: mathieu.doucet@nist.gov
55    """
56   
57    def __init__(self, filename):
58        """ Initialization """
59       
60        ## Name of .h file to generate wrapper from
61        self.file = filename
62       
63        # Info read from file
64       
65        ## Name of python class to write
66        self.pythonClass = None
67        ## Parser in struct section
68        self.inStruct = False
69        ## Name of struct for the c object
70        self.structName = None
71        ## Dictionary of parameters
72        self.params = {}
73        ## ModelCalculation module flag
74        self.modelCalcFlag = False
75        ## List of default parameters (text)
76        self.default_list = ""
77        ## Dictionary of units
78        self.details = ""
79        ## List of dispersed parameters
80        self.disp_params = []
81        #model description
82        self.description=''
83       
84    def __repr__(self):
85        """ Simple output for printing """
86       
87        rep  = "Python class: %s\n" % self.pythonClass
88        rep += "  struc name: %s\n" % self.structName
89        rep += "  params:     %s\n" % self.params
90        rep += "  description:     %s\n" % self.description
91        return rep
92       
93    def read(self):
94        """ Reads in the .h file to catch parameters of the wrapper """
95       
96        # Check if the file is there
97        if not os.path.isfile(self.file):
98            raise ValueError, "File %s is not a regular file" % self.file
99       
100        # Read file
101        f = open(self.file,'r')
102        buf = f.read()
103       
104        self.default_list = "List of default parameters:\n"
105        #lines = string.split(buf,'\n')
106        lines = buf.split('\n')
107        self.details  = "## Parameter details [units, min, max]\n"
108        self.details += "        self.details = {}\n"
109         # Catch Description
110        key = "[DESCRIPTION]"
111        find_description= 0
112        temp=""
113        for line in lines:
114            if line.count(key)>0 :
115               
116                try:
117                    find_description= 1
118                    index = line.index(key)
119                    toks = line[index:].split("=",1 )
120                    temp=toks[1].lstrip().rstrip()
121                    text='text'
122                    key2="<%s>"%text.lower()
123                    if re.match(key2,temp)!=None:
124                        #index2 = line.index(key2)
125                        #temp = temp[index2:]
126                        toks2=temp.split(key2,1)
127                        self.description=toks2[1]
128                        text='text'
129                        key2="</%s>"%text.lower()
130                        if re.search(key2,toks2[1])!=None:
131                            temp=toks2[1].split(key2,1)
132                            self.description=temp[0]
133                            break
134                        #print self.description
135                    else:
136                        self.description=temp
137                        break
138                except:
139                     raise ValueError, "Could not parse file %s" % self.file
140            elif find_description==1:
141                text='text'
142                key2="</%s>"%text.lower()
143                #print "second line",line,key2,re.search(key2,line)
144                if re.search(key2,line)!=None:
145                    tok=line.split(key2,1)
146                    temp=tok[0].split("//",1)
147                    self.description+=tok[1].lstrip().rstrip()
148                    break
149                else:
150                    #if re.search("*",line)!=None:
151                    #    temp=line.split("*",1)
152                    #    self.description+='\n'+temp[1].lstrip().rstrip()
153                    if re.search("//",line)!=None:
154                        temp=line.split("//",1)
155                        self.description+='\n\t\t'+temp[1].lstrip().rstrip()
156                       
157                    else:
158                        self.description+='\n\t\t'+line.lstrip().rstrip()
159                   
160               
161               
162        for line in lines:
163           
164            # Catch class name
165            key = "[PYTHONCLASS]"
166            if line.count(key)>0:
167                try:
168                    index = line.index(key)
169                    #toks = string.split( line[index:], "=" )
170                    toks = line[index:].split("=" )
171                    self.pythonClass = toks[1].lstrip().rstrip()
172                except:
173                    raise ValueError, "Could not parse file %s" % self.file
174               
175            # Catch struct name
176            if line.count("typedef struct")>0:
177                # We are entering a struct block
178                self.inStruct = True
179           
180            if self.inStruct and line.count("}")>0:
181                # We are exiting a struct block
182                self.inStruct = False
183   
184                # Catch the name of the struct
185                index = line.index("}")
186                #toks = string.split(line[index+1:],";")
187                toks = line[index+1:].split(";")
188                # Catch pointer definition
189                #toks2 = string.split(toks[0],',')
190                toks2 = toks[0].split(',')
191                self.structName = toks2[0].lstrip().rstrip()
192           
193               
194           
195            # Catch struct content
196            key = "[DEFAULT]"
197            if self.inStruct and line.count(key)>0:
198                # Found a new parameter
199                try:
200                    index = line.index(key)
201                    toks = line[index:].split("=")
202                    toks2 = toks[2].split()
203                    val = float(toks2[0])
204                    self.params[toks[1]] = val
205                    #self.pythonClass = toks[1].lstrip().rstrip()
206                    units = ""
207                    if len(toks2) >= 2:
208                        units = toks2[1]
209                    self.default_list += "         %-15s = %s %s\n" % \
210                        (toks[1], val, units)
211                   
212                    # Check for min and max
213                    min = "None"
214                    max = "None"
215                    if len(toks2) == 4:
216                        min = toks2[2]
217                        max = toks2[3]
218                   
219                    self.details += "        self.details['%s'] = ['%s', %s, %s]\n" % \
220                        (toks[1].lstrip().rstrip(), units.lstrip().rstrip(), min, max)
221                except:
222                    raise ValueError, "Could not parse input file %s \n  %s" % \
223                        (self.file, sys.exc_value)
224               
225               
226            # Catch need for numerical calculations
227            key = "CalcParameters calcPars"
228            if line.count(key)>0:
229                self.modelCalcFlag = True
230               
231            # Catch list of dispersed parameters
232            key = "[DISP_PARAMS]"
233            if line.count(key)>0:
234                try:
235                    index = line.index(key)
236                    toks = line[index:].split("=")
237                    list_str = toks[1].lstrip().rstrip()
238                    self.disp_params = list_str.split(',')
239                except:
240                    raise ValueError, "Could not parse file %s" % self.file
241               
242               
243               
244    def write_c_wrapper(self):
245        """ Writes the C file to create the python extension class
246            The file is written in C[PYTHONCLASS].c
247        """
248       
249        file = open("C"+self.pythonClass+'.cpp', 'w')
250        template = open("classTemplate.txt", 'r')
251       
252        tmp_buf = template.read()
253        #tmp_lines = string.split(tmp_buf,'\n')
254        tmp_lines = tmp_buf.split('\n')
255       
256        for tmp_line in tmp_lines:
257           
258            # Catch class name
259            newline = self.replaceToken(tmp_line, 
260                                        "[PYTHONCLASS]", 'C'+self.pythonClass)
261            #Catch model description
262            newline = self.replaceToken(tmp_line, 
263                                        "[DESCRIPTION]", self.description)
264            # Catch C model name
265            newline = self.replaceToken(newline, 
266                                        "[CMODEL]", self.pythonClass)
267           
268            # Catch class name
269            newline = self.replaceToken(newline, 
270                                        "[MODELSTRUCT]", self.structName)
271           
272            # Dictionary initialization
273            param_str = "// Initialize parameter dictionary\n"           
274            for par in self.params:
275                param_str += "        PyDict_SetItemString(self->params,\"%s\",Py_BuildValue(\"d\",%f));\n" % \
276                    (par, self.params[par])
277
278            param_str += "        // Initialize dispersion / averaging parameter dict\n"
279            param_str += "        DispersionVisitor* visitor = new DispersionVisitor();\n"
280            param_str += "        PyObject * disp_dict;\n"
281            for par in self.disp_params:
282                par = par.strip()
283                param_str += "        disp_dict = PyDict_New();\n"
284                param_str += "        self->model->%s.dispersion->accept_as_source(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
285                param_str += "        PyDict_SetItemString(self->dispersion, \"%s\", disp_dict);\n" % par
286               
287            # Initialize dispersion object dictionnary
288            param_str += "\n"
289           
290               
291            newline = self.replaceToken(newline,
292                                        "[INITDICTIONARY]", param_str)
293           
294            # Read dictionary
295            param_str = "    // Reader parameter dictionary\n"
296            for par in self.params:
297                param_str += "    self->model->%s = PyFloat_AsDouble( PyDict_GetItemString(self->params, \"%s\") );\n" % \
298                    (par, par)
299                   
300            param_str += "    // Read in dispersion parameters\n"
301            param_str += "    PyObject* disp_dict;\n"
302            param_str += "    DispersionVisitor* visitor = new DispersionVisitor();\n"
303            for par in self.disp_params:
304                par = par.strip()
305                param_str += "    disp_dict = PyDict_GetItemString(self->dispersion, \"%s\");\n" % par
306                param_str += "    self->model->%s.dispersion->accept_as_destination(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
307               
308            newline = self.replaceToken(newline, "[READDICTIONARY]", param_str)
309               
310            # Name of .c file
311            #toks = string.split(self.file,'.')
312            basename = os.path.basename(self.file)
313            toks = basename.split('.')
314            newline = self.replaceToken(newline, "[C_FILENAME]", toks[0])
315           
316            # Include file
317            basename = os.path.basename(self.file)
318            newline = self.replaceToken(newline, 
319                                        "[INCLUDE_FILE]", basename)           
320               
321            # Numerical calcs dealloc
322            dealloc_str = "\n"
323            if self.modelCalcFlag:
324                dealloc_str = "    modelcalculations_dealloc(&(self->model_pars.calcPars));\n"
325            newline = self.replaceToken(newline, 
326                                        "[NUMERICAL_DEALLOC]", dealloc_str)     
327               
328            # Numerical calcs init
329            init_str = "\n"
330            if self.modelCalcFlag:
331                init_str = "        modelcalculations_init(&(self->model_pars.calcPars));\n"
332            newline = self.replaceToken(newline, 
333                                        "[NUMERICAL_INIT]", init_str)     
334               
335            # Numerical calcs reset
336            reset_str = "\n"
337            if self.modelCalcFlag:
338                reset_str = "modelcalculations_reset(&(self->model_pars.calcPars));\n"
339            newline = self.replaceToken(newline, 
340                                        "[NUMERICAL_RESET]", reset_str)     
341               
342            # Setting dispsertion weights
343            set_weights = "    // Ugliness necessary to go from python to C\n"
344            set_weights = "    // TODO: refactor this\n"
345            for par in self.disp_params:
346                par = par.strip()
347                set_weights += "    if (!strcmp(par_name, \"%s\")) {\n" % par
348                set_weights += "        self->model->%s.dispersion = dispersion;\n" % par
349                set_weights += "    } else"
350            newline = self.replaceToken(newline, 
351                                        "[SET_DISPERSION]", set_weights)     
352           
353           
354            # Write new line to the wrapper .c file
355            file.write(newline+'\n')
356           
357           
358        file.close()
359       
360    def write_python_wrapper(self):
361        """ Writes the python file to create the python extension class
362            The file is written in ../[PYTHONCLASS].py
363        """
364       
365        file = open("../"+self.pythonClass+'.py', 'w')
366        template = open("modelTemplate.txt", 'r')
367       
368        tmp_buf = template.read()
369        tmp_lines = tmp_buf.split('\n')
370       
371        for tmp_line in tmp_lines:
372           
373            # Catch class name
374            newline = self.replaceToken(tmp_line, 
375                                        "[CPYTHONCLASS]", 'C'+self.pythonClass)
376           
377            # Catch class name
378            newline = self.replaceToken(newline, 
379                                        "[PYTHONCLASS]", self.pythonClass)
380           
381            # Include file
382            newline = self.replaceToken(newline, 
383                                        "[INCLUDE_FILE]", self.file)   
384                   
385            # Include file
386            newline = self.replaceToken(newline, 
387                                        "[DEFAULT_LIST]", self.default_list)
388
389            # Parameter details
390            newline = self.replaceToken(newline, 
391                                        "[PAR_DETAILS]", self.details)
392            # Parameter details
393            newline = self.replaceToken(newline, 
394                                        "[DESCRIPTION]", self.description)
395
396            # Write new line to the wrapper .c file
397            file.write(newline+'\n')
398               
399        file.close()
400       
401       
402    def replaceToken(self, line, key, value): #pylint: disable-msg=R0201
403        """ Replace a token in the template file
404            @param line: line of text to inspect
405            @param key: token to look for
406            @param value: string value to replace the token with
407            @return: new string value
408        """
409        lenkey = len(key)
410        newline = line
411        while newline.count(key)>0:
412            index = newline.index(key)
413            newline = newline[:index]+value+newline[index+lenkey:]
414        return newline
415       
416       
417# main
418if __name__ == '__main__':
419    if len(sys.argv)>1:
420        print "Will look for file %s" % sys.argv[1]
421        app = WrapperGenerator(sys.argv[1])
422    else:
423        app = WrapperGenerator("test.h")
424    app.read()
425    app.write_c_wrapper()
426    app.write_python_wrapper()
427    print app
428   
429# End of file       
Note: See TracBrowser for help on using the repository browser.