source: sasview/sansmodels/src/sans/models/c_models/WrapperGenerator.py @ af03ddd

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since af03ddd was af03ddd, checked in by Mathieu Doucet <doucetm@…>, 16 years ago

Model C extension update

  • Property mode set to 100644
File size: 14.4 KB
Line 
1#!/usr/bin/env python
2""" WrapperGenerator class to generate model code automatically.
3"""
4
5import os, sys
6
7class WrapperGenerator:
8    """ Python wrapper generator for C models
9   
10        The developer must provide a header file describing
11        the new model.
12       
13        To provide the name of the Python class to be
14        generated, the .h file must contain the following
15        string in the comments:
16       
17        // [PYTHONCLASS] = my_model
18       
19        where my_model must be replaced by the name of the
20        class that you want to import from sans.models.
21        (example: [PYTHONCLASS] = MyModel
22          will create a class MyModel in sans.models.MyModel.
23          It will also create a class CMyModel in
24          sans_extension.c_models.)
25         
26        Also in comments, each parameter of the params
27        dictionary must be declared with a default value
28        in the following way:
29       
30        // [DEFAULT]=param_name=default_value
31       
32        (example:
33            //  [DEFAULT]=radius=20.0
34        )
35         
36        See cylinder.h for an example.
37       
38       
39        A .c file corresponding to the .h file should also
40        be provided (example: my_model.h, my_model.c).
41   
42        The .h file should define two function definitions. For example,
43        cylinder.h defines the following:
44       
45            /// 1D scattering function
46            double cylinder_analytical_1D(CylinderParameters *pars, double q);
47           
48            /// 2D scattering function
49            double cylinder_analytical_2D(CylinderParameters *pars, double q, double phi);
50           
51        The .c file implements those functions.
52       
53        @author: Mathieu Doucet / UTK
54        @contact: mathieu.doucet@nist.gov
55    """
56   
57    def __init__(self, filename):
58        """ Initialization """
59       
60        ## Name of .h file to generate wrapper from
61        self.file = filename
62       
63        # Info read from file
64       
65        ## Name of python class to write
66        self.pythonClass = None
67        ## Parser in struct section
68        self.inStruct = False
69        ## Name of struct for the c object
70        self.structName = None
71        ## Dictionary of parameters
72        self.params = {}
73        ## ModelCalculation module flag
74        self.modelCalcFlag = False
75        ## List of default parameters (text)
76        self.default_list = ""
77        ## Dictionary of units
78        self.details = ""
79        ## List of dispersed parameters
80        self.disp_params = []
81       
82    def __repr__(self):
83        """ Simple output for printing """
84       
85        rep  = "Python class: %s\n" % self.pythonClass
86        rep += "  struc name: %s\n" % self.structName
87        rep += "  params:     %s\n" % self.params
88        return rep
89       
90    def read(self):
91        """ Reads in the .h file to catch parameters of the wrapper """
92       
93        # Check if the file is there
94        if not os.path.isfile(self.file):
95            raise ValueError, "File %s is not a regular file" % self.file
96       
97        # Read file
98        f = open(self.file,'r')
99        buf = f.read()
100       
101        self.default_list = "List of default parameters:\n"
102        #lines = string.split(buf,'\n')
103        lines = buf.split('\n')
104        self.details  = "## Parameter details [units, min, max]\n"
105        self.details += "        self.details = {}\n"
106        for line in lines:
107           
108            # Catch class name
109            key = "[PYTHONCLASS]"
110            if line.count(key)>0:
111                try:
112                    index = line.index(key)
113                    #toks = string.split( line[index:], "=" )
114                    toks = line[index:].split("=" )
115                    self.pythonClass = toks[1].lstrip().rstrip()
116                except:
117                    raise ValueError, "Could not parse file %s" % self.file
118               
119            # Catch struct name
120            if line.count("typedef struct")>0:
121                # We are entering a struct block
122                self.inStruct = True
123           
124            if self.inStruct and line.count("}")>0:
125                # We are exiting a struct block
126                self.inStruct = False
127   
128                # Catch the name of the struct
129                index = line.index("}")
130                #toks = string.split(line[index+1:],";")
131                toks = line[index+1:].split(";")
132                # Catch pointer definition
133                #toks2 = string.split(toks[0],',')
134                toks2 = toks[0].split(',')
135                self.structName = toks2[0].lstrip().rstrip()
136               
137            # Catch struct content
138            key = "[DEFAULT]"
139            if self.inStruct and line.count(key)>0:
140                # Found a new parameter
141                try:
142                    index = line.index(key)
143                    toks = line[index:].split("=")
144                    toks2 = toks[2].split()
145                    val = float(toks2[0])
146                    self.params[toks[1]] = val
147                    #self.pythonClass = toks[1].lstrip().rstrip()
148                    units = ""
149                    if len(toks2) >= 2:
150                        units = toks2[1]
151                    self.default_list += "         %-15s = %s %s\n" % \
152                        (toks[1], val, units)
153                   
154                    # Check for min and max
155                    min = "None"
156                    max = "None"
157                    if len(toks2) == 4:
158                        min = toks2[2]
159                        max = toks2[3]
160                   
161                    self.details += "        self.details['%s'] = ['%s', %s, %s]\n" % \
162                        (toks[1].lstrip().rstrip(), units.lstrip().rstrip(), min, max)
163                except:
164                    raise ValueError, "Could not parse input file %s \n  %s" % \
165                        (self.file, sys.exc_value)
166               
167               
168            # Catch need for numerical calculations
169            key = "CalcParameters calcPars"
170            if line.count(key)>0:
171                self.modelCalcFlag = True
172               
173            # Catch list of dispersed parameters
174            key = "[DISP_PARAMS]"
175            if line.count(key)>0:
176                try:
177                    index = line.index(key)
178                    toks = line[index:].split("=")
179                    list_str = toks[1].lstrip().rstrip()
180                    self.disp_params = list_str.split(',')
181                except:
182                    raise ValueError, "Could not parse file %s" % self.file
183               
184               
185               
186    def write_c_wrapper(self):
187        """ Writes the C file to create the python extension class
188            The file is written in C[PYTHONCLASS].c
189        """
190       
191        file = open("C"+self.pythonClass+'.cpp', 'w')
192        template = open("classTemplate.txt", 'r')
193       
194        tmp_buf = template.read()
195        #tmp_lines = string.split(tmp_buf,'\n')
196        tmp_lines = tmp_buf.split('\n')
197       
198        for tmp_line in tmp_lines:
199           
200            # Catch class name
201            newline = self.replaceToken(tmp_line, 
202                                        "[PYTHONCLASS]", 'C'+self.pythonClass)
203           
204            # Catch C model name
205            newline = self.replaceToken(newline, 
206                                        "[CMODEL]", self.pythonClass)
207           
208            # Catch class name
209            newline = self.replaceToken(newline, 
210                                        "[MODELSTRUCT]", self.structName)
211           
212            # Dictionary initialization
213            param_str = "// Initialize parameter dictionary\n"           
214            for par in self.params:
215                param_str += "        PyDict_SetItemString(self->params,\"%s\",Py_BuildValue(\"d\", self->model->%s()));\n" % \
216                    (par, par)
217
218            param_str += "        // Initialize dispersion / averaging parameter dict\n"
219            param_str += "        DispersionVisitor* visitor = new DispersionVisitor();\n"
220            param_str += "        PyObject * disp_dict;\n"
221            for par in self.disp_params:
222                par = par.strip()
223                param_str += "        disp_dict = PyDict_New();\n"
224                param_str += "        self->model->%s.dispersion->accept_as_source(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
225                param_str += "        PyDict_SetItemString(self->dispersion, \"%s\", disp_dict);\n" % par
226               
227            # Initialize dispersion object dictionnary
228            param_str += "\n"
229           
230               
231            newline = self.replaceToken(newline,
232                                        "[INITDICTIONARY]", param_str)
233           
234            # Read dictionary
235            param_str = "    // Reader parameter dictionary\n"
236            for par in self.params:
237                param_str += "    self->model->%s = PyFloat_AsDouble( PyDict_GetItemString(self->params, \"%s\") );\n" % \
238                    (par, par)
239                   
240            param_str += "    // Read in dispersion parameters\n"
241            param_str += "    PyObject* disp_dict;\n"
242            param_str += "    DispersionVisitor* visitor = new DispersionVisitor();\n"
243            for par in self.disp_params:
244                par = par.strip()
245                param_str += "    disp_dict = PyDict_GetItemString(self->dispersion, \"%s\");\n" % par
246                param_str += "    self->model->%s.dispersion->accept_as_destination(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
247               
248            newline = self.replaceToken(newline, "[READDICTIONARY]", param_str)
249               
250            # Name of .c file
251            #toks = string.split(self.file,'.')
252            basename = os.path.basename(self.file)
253            toks = basename.split('.')
254            newline = self.replaceToken(newline, "[C_FILENAME]", toks[0])
255           
256            # Include file
257            basename = os.path.basename(self.file)
258            newline = self.replaceToken(newline, 
259                                        "[INCLUDE_FILE]", basename)           
260               
261            # Numerical calcs dealloc
262            dealloc_str = "\n"
263            if self.modelCalcFlag:
264                dealloc_str = "    modelcalculations_dealloc(&(self->model_pars.calcPars));\n"
265            newline = self.replaceToken(newline, 
266                                        "[NUMERICAL_DEALLOC]", dealloc_str)     
267               
268            # Numerical calcs init
269            init_str = "\n"
270            if self.modelCalcFlag:
271                init_str = "        modelcalculations_init(&(self->model_pars.calcPars));\n"
272            newline = self.replaceToken(newline, 
273                                        "[NUMERICAL_INIT]", init_str)     
274               
275            # Numerical calcs reset
276            reset_str = "\n"
277            if self.modelCalcFlag:
278                reset_str = "modelcalculations_reset(&(self->model_pars.calcPars));\n"
279            newline = self.replaceToken(newline, 
280                                        "[NUMERICAL_RESET]", reset_str)     
281               
282            # Setting dispsertion weights
283            set_weights = "    // Ugliness necessary to go from python to C\n"
284            set_weights = "    // TODO: refactor this\n"
285            for par in self.disp_params:
286                par = par.strip()
287                set_weights += "    if (!strcmp(par_name, \"%s\")) {\n" % par
288                set_weights += "        self->model->%s.dispersion = dispersion;\n" % par
289                set_weights += "    } else"
290            newline = self.replaceToken(newline, 
291                                        "[SET_DISPERSION]", set_weights)     
292           
293           
294            # Write new line to the wrapper .c file
295            file.write(newline+'\n')
296           
297           
298        file.close()
299       
300    def write_python_wrapper(self):
301        """ Writes the python file to create the python extension class
302            The file is written in ../[PYTHONCLASS].py
303        """
304       
305        file = open("../"+self.pythonClass+'.py', 'w')
306        template = open("modelTemplate.txt", 'r')
307       
308        tmp_buf = template.read()
309        tmp_lines = tmp_buf.split('\n')
310       
311        for tmp_line in tmp_lines:
312           
313            # Catch class name
314            newline = self.replaceToken(tmp_line, 
315                                        "[CPYTHONCLASS]", 'C'+self.pythonClass)
316           
317            # Catch class name
318            newline = self.replaceToken(newline, 
319                                        "[PYTHONCLASS]", self.pythonClass)
320           
321            # Include file
322            newline = self.replaceToken(newline, 
323                                        "[INCLUDE_FILE]", self.file)   
324                   
325            # Include file
326            newline = self.replaceToken(newline, 
327                                        "[DEFAULT_LIST]", self.default_list)
328
329            # Parameter details
330            newline = self.replaceToken(newline, 
331                                        "[PAR_DETAILS]", self.details)
332
333            # Write new line to the wrapper .c file
334            file.write(newline+'\n')
335               
336        file.close()
337       
338       
339    def replaceToken(self, line, key, value): #pylint: disable-msg=R0201
340        """ Replace a token in the template file
341            @param line: line of text to inspect
342            @param key: token to look for
343            @param value: string value to replace the token with
344            @return: new string value
345        """
346        lenkey = len(key)
347        newline = line
348        while newline.count(key)>0:
349            index = newline.index(key)
350            newline = newline[:index]+value+newline[index+lenkey:]
351        return newline
352       
353       
354# main
355if __name__ == '__main__':
356    if len(sys.argv)>1:
357        print "Will look for file %s" % sys.argv[1]
358        app = WrapperGenerator(sys.argv[1])
359    else:
360        app = WrapperGenerator("test.h")
361    app.read()
362    app.write_c_wrapper()
363    app.write_python_wrapper()
364    print app
365   
366# End of file       
Note: See TracBrowser for help on using the repository browser.