source: sasview/sansmodels/src/python_wrapper/WrapperGenerator.py @ a1d1b90

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since a1d1b90 was a1d1b90, checked in by Mathieu Doucet <doucetm@…>, 13 years ago

updated wrapperGenerator

  • Property mode set to 100644
File size: 19.9 KB
Line 
1#!/usr/bin/env python
2""" WrapperGenerator class to generate model code automatically.
3"""
4
5import os, sys,re
6import lineparser
7
8class WrapperGenerator:
9    """ Python wrapper generator for C models
10   
11        The developer must provide a header file describing
12        the new model.
13       
14        To provide the name of the Python class to be
15        generated, the .h file must contain the following
16        string in the comments:
17       
18        // [PYTHONCLASS] = my_model
19       
20        where my_model must be replaced by the name of the
21        class that you want to import from sans.models.
22        (example: [PYTHONCLASS] = MyModel
23          will create a class MyModel in sans.models.MyModel.
24          It will also create a class CMyModel in
25          sans_extension.c_models.)
26         
27        Also in comments, each parameter of the params
28        dictionary must be declared with a default value
29        in the following way:
30       
31        // [DEFAULT]=param_name=default_value
32       
33        (example:
34            //  [DEFAULT]=radius=20.0
35        )
36         
37        See cylinder.h for an example.
38       
39       
40        A .c file corresponding to the .h file should also
41        be provided (example: my_model.h, my_model.c).
42   
43        The .h file should define two function definitions. For example,
44        cylinder.h defines the following:
45       
46            /// 1D scattering function
47            double cylinder_analytical_1D(CylinderParameters *pars, double q);
48           
49            /// 2D scattering function
50            double cylinder_analytical_2D(CylinderParameters *pars, double q, double phi);
51           
52        The .c file implements those functions.
53       
54        @author: Mathieu Doucet / UTK
55        @contact: mathieu.doucet@nist.gov
56    """
57   
58    def __init__(self, filename):
59        """ Initialization """
60       
61        ## Name of .h file to generate wrapper from
62        self.file = filename
63       
64        # Info read from file
65       
66        ## Name of python class to write
67        self.pythonClass = None
68        ## Parser in struct section
69        self.inStruct = False
70        self.foundCPP = False
71        ## Name of struct for the c object
72        self.structName = None
73        ## Dictionary of parameters
74        self.params = {}
75        ## ModelCalculation module flag
76        self.modelCalcFlag = False
77        ## List of default parameters (text)
78        self.default_list = ""
79        ## Dictionary of units
80        self.details = ""
81        ## List of dispersed parameters
82        self.disp_params = []
83        #model description
84        self.description=''
85        # paramaters for fittable
86        self.fixed= []
87        # paramaters for non-fittable
88        self.non_fittable= []
89        ## parameters with orientation
90        self.orientation_params =[]
91       
92       
93    def __repr__(self):
94        """ Simple output for printing """
95       
96        rep  = "\n Python class: %s\n\n" % self.pythonClass
97        rep += "  struc name: %s\n\n" % self.structName
98        rep += "  params:     %s\n\n" % self.params
99        rep += "  description:    %s\n\n" % self.description
100        rep += "  Fittable parameters:     %s\n\n"% self.fixed
101        rep += "  Non-Fittable parameters:     %s\n\n"% self.non_fittable
102        rep += "  Orientation parameters:  %s\n\n"% self.orientation_params
103        return rep
104       
105    def read(self):
106        """ Reads in the .h file to catch parameters of the wrapper """
107       
108        # Check if the file is there
109        if not os.path.isfile(self.file):
110            raise ValueError, "File %s is not a regular file" % self.file
111       
112        # Read file
113        f = open(self.file,'r')
114        buf = f.read()
115       
116        self.default_list = "List of default parameters:\n"
117        #lines = string.split(buf,'\n')
118        lines = buf.split('\n')
119        self.details  = "## Parameter details [units, min, max]\n"
120        self.details += "        self.details = {}\n"
121       
122        #open item in this case Fixed
123        text='text'
124        key2="<%s>"%text.lower()
125        # close an item in this case fixed
126        text='TexT'
127        key3="</%s>"%text.lower()
128       
129        ## Catch fixed parameters
130        key = "[FIXED]"
131        try:
132            self.fixed= lineparser.readhelper(lines,key, key2,key3, file= self.file)
133        except:
134           raise   
135        ## Catch non-fittable parameters parameters
136        key = "[NON_FITTABLE_PARAMS]"
137        try:
138            self.non_fittable= lineparser.readhelper(lines,key, key2,key3, file= self.file)
139        except:
140           raise   
141
142        ## Catch parameters with orientation
143        key = "[ORIENTATION_PARAMS]"   
144        try:
145            self.orientation_params = lineparser.readhelper( lines,key, 
146                                                    key2,key3, file= self.file)
147        except:
148           raise 
149        ## Catch Description
150        key = "[DESCRIPTION]"
151       
152        find_description = False
153        temp=""
154        for line in lines:
155            if line.count(key)>0 :
156               
157                try:
158                    find_description= True
159                    index = line.index(key)
160                    toks = line[index:].split("=",1 )
161                    temp=toks[1].lstrip().rstrip()
162                    text='text'
163                    key2="<%s>"%text.lower()
164                    if re.match(key2,temp)!=None:
165   
166                        toks2=temp.split(key2,1)
167                        self.description=toks2[1]
168                        text='text'
169                        key2="</%s>"%text.lower()
170                        if re.search(key2,toks2[1])!=None:
171                            temp=toks2[1].split(key2,1)
172                            self.description=temp[0]
173                            break
174                     
175                    else:
176                        self.description=temp
177                        break
178                except:
179                     raise ValueError, "Could not parse file %s" % self.file
180            elif find_description:
181                text='text'
182                key2="</%s>"%text.lower()
183                if re.search(key2,line)!=None:
184                    tok=line.split(key2,1)
185                    temp=tok[0].split("//",1)
186                    self.description+=tok[1].lstrip().rstrip()
187                    break
188                else:
189                    if re.search("//",line)!=None:
190                        temp=line.split("//",1)
191                        self.description+='\n\t\t'+temp[1].lstrip().rstrip()
192                       
193                    else:
194                        self.description+='\n\t\t'+line.lstrip().rstrip()
195                   
196               
197               
198        for line in lines:
199           
200            # Catch class name
201            key = "[PYTHONCLASS]"
202            if line.count(key)>0:
203                try:
204                    index = line.index(key)
205                    #toks = string.split( line[index:], "=" )
206                    toks = line[index:].split("=" )
207                    self.pythonClass = toks[1].lstrip().rstrip()
208                except:
209                    raise ValueError, "Could not parse file %s" % self.file
210               
211            # Catch struct name
212            # C++ class definition
213            if line.count("class")>0:
214                # We are entering a class definition
215                self.inStruct = True
216                self.foundCPP = True
217                class_name = line.replace('}','')
218                class_name = class_name.replace('class','')
219                self.inStruct = class_name.strip()
220           
221            if self.inStruct and line.count("}")>0:
222                # We are exiting a struct block
223                self.inStruct = False
224               
225            # Old-Style C struct definition
226            if line.count("typedef struct")>0:
227                # We are entering a struct block
228                self.inStruct = True
229           
230            if self.inStruct and line.count("}")>0:
231                # We are exiting a struct block
232                self.inStruct = False
233   
234                # Catch the name of the struct
235                index = line.index("}")
236                #toks = string.split(line[index+1:],";")
237                toks = line[index+1:].split(";")
238                # Catch pointer definition
239                #toks2 = string.split(toks[0],',')
240                toks2 = toks[0].split(',')
241                self.structName = toks2[0].lstrip().rstrip()
242           
243               
244           
245            # Catch struct content
246            key = "[DEFAULT]"
247            if self.inStruct and line.count(key)>0:
248                # Found a new parameter
249                try:
250                    index = line.index(key)
251                    toks = line[index:].split("=")
252                    toks2 = toks[2].split()
253                    val = float(toks2[0])
254                    self.params[toks[1]] = val
255                    #self.pythonClass = toks[1].lstrip().rstrip()
256                    units = ""
257                    if len(toks2) >= 2:
258                        units = toks2[1]
259                    self.default_list += "         %-15s = %s %s\n" % \
260                        (toks[1], val, units)
261                   
262                    # Check for min and max
263                    min = "None"
264                    max = "None"
265                    if len(toks2) == 4:
266                        min = toks2[2]
267                        max = toks2[3]
268                   
269                    self.details += "        self.details['%s'] = ['%s', %s, %s]\n" % \
270                        (toks[1].lstrip().rstrip(), units.lstrip().rstrip(), min, max)
271                except:
272                    raise ValueError, "Could not parse input file %s \n  %s" % \
273                        (self.file, sys.exc_value)
274               
275               
276            # Catch need for numerical calculations
277            key = "CalcParameters calcPars"
278            if line.count(key)>0:
279                self.modelCalcFlag = True
280               
281            # Catch list of dispersed parameters
282            key = "[DISP_PARAMS]"
283            if line.count(key)>0:
284                try:
285                    index = line.index(key)
286                    toks = line[index:].split("=")
287                    list_str = toks[1].lstrip().rstrip()
288                    self.disp_params = list_str.split(',')
289                except:
290                    raise ValueError, "Could not parse file %s" % self.file
291               
292       
293               
294    def write_c_wrapper(self):
295        """ Writes the C file to create the python extension class
296            The file is written in C[PYTHONCLASS].c
297        """
298       
299        file = open("C"+self.pythonClass+'.cpp', 'w')
300        template = open("classTemplate.txt", 'r')
301       
302        tmp_buf = template.read()
303        #tmp_lines = string.split(tmp_buf,'\n')
304        tmp_lines = tmp_buf.split('\n')
305       
306        for tmp_line in tmp_lines:
307           
308            # Catch class name
309            newline = self.replaceToken(tmp_line, 
310                                        "[PYTHONCLASS]", 'C'+self.pythonClass)
311            #Catch model description
312            #newline = self.replaceToken(tmp_line,
313            #                            "[DESCRIPTION]", self.description)
314            # Catch C model name
315            newline = self.replaceToken(newline, 
316                                        "[CMODEL]", self.pythonClass)
317           
318            # Catch class name
319            newline = self.replaceToken(newline, 
320                                        "[MODELSTRUCT]", self.structName)
321           
322            # Dictionary initialization
323            param_str = "// Initialize parameter dictionary\n"           
324            for par in self.params:
325                param_str += "        PyDict_SetItemString(self->params,\"%s\",Py_BuildValue(\"d\",%10.12f));\n" % \
326                    (par, self.params[par])
327
328            param_str += "        // Initialize dispersion / averaging parameter dict\n"
329            param_str += "        DispersionVisitor* visitor = new DispersionVisitor();\n"
330            param_str += "        PyObject * disp_dict;\n"
331            for par in self.disp_params:
332                par = par.strip()
333                param_str += "        disp_dict = PyDict_New();\n"
334                param_str += "        self->model->%s.dispersion->accept_as_source(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
335                param_str += "        PyDict_SetItemString(self->dispersion, \"%s\", disp_dict);\n" % par
336               
337            # Initialize dispersion object dictionnary
338            param_str += "\n"
339           
340               
341            newline = self.replaceToken(newline,
342                                        "[INITDICTIONARY]", param_str)
343           
344            # Read dictionary
345            param_str = "    // Reader parameter dictionary\n"
346            for par in self.params:
347                param_str += "    self->model->%s = PyFloat_AsDouble( PyDict_GetItemString(self->params, \"%s\") );\n" % \
348                    (par, par)
349                   
350            param_str += "    // Read in dispersion parameters\n"
351            param_str += "    PyObject* disp_dict;\n"
352            param_str += "    DispersionVisitor* visitor = new DispersionVisitor();\n"
353            for par in self.disp_params:
354                par = par.strip()
355                param_str += "    disp_dict = PyDict_GetItemString(self->dispersion, \"%s\");\n" % par
356                param_str += "    self->model->%s.dispersion->accept_as_destination(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
357               
358            newline = self.replaceToken(newline, "[READDICTIONARY]", param_str)
359               
360            # Name of .c file
361            #toks = string.split(self.file,'.')
362            basename = os.path.basename(self.file)
363            toks = basename.split('.')
364            newline = self.replaceToken(newline, "[C_FILENAME]", toks[0])
365           
366            # Include file
367            basename = os.path.basename(self.file)
368            newline = self.replaceToken(newline, 
369                                        "[INCLUDE_FILE]", self.file) 
370            if self.foundCPP:
371                newline = self.replaceToken(newline, 
372                                            "[C_INCLUDE_FILE]", "") 
373                newline = self.replaceToken(newline, 
374                                            "[CPP_INCLUDE_FILE]", "#include \"%s\"" % basename) 
375            else: 
376                newline = self.replaceToken(newline, 
377                                            "[C_INCLUDE_FILE]", "#include \"%s\"" % basename)   
378                newline = self.replaceToken(newline, 
379                                            "[CPP_INCLUDE_FILE]", "#include \"models.hh\"") 
380               
381            # Numerical calcs dealloc
382            dealloc_str = "\n"
383            if self.modelCalcFlag:
384                dealloc_str = "    modelcalculations_dealloc(&(self->model_pars.calcPars));\n"
385            newline = self.replaceToken(newline, 
386                                        "[NUMERICAL_DEALLOC]", dealloc_str)     
387               
388            # Numerical calcs init
389            init_str = "\n"
390            if self.modelCalcFlag:
391                init_str = "        modelcalculations_init(&(self->model_pars.calcPars));\n"
392            newline = self.replaceToken(newline, 
393                                        "[NUMERICAL_INIT]", init_str)     
394               
395            # Numerical calcs reset
396            reset_str = "\n"
397            if self.modelCalcFlag:
398                reset_str = "modelcalculations_reset(&(self->model_pars.calcPars));\n"
399            newline = self.replaceToken(newline, 
400                                        "[NUMERICAL_RESET]", reset_str)     
401               
402            # Setting dispsertion weights
403            set_weights = "    // Ugliness necessary to go from python to C\n"
404            set_weights = "    // TODO: refactor this\n"
405            for par in self.disp_params:
406                par = par.strip()
407                set_weights += "    if (!strcmp(par_name, \"%s\")) {\n" % par
408                set_weights += "        self->model->%s.dispersion = dispersion;\n" % par
409                set_weights += "    } else"
410            newline = self.replaceToken(newline, 
411                                        "[SET_DISPERSION]", set_weights)     
412           
413            # Write new line to the wrapper .c file
414            file.write(newline+'\n')
415           
416           
417        file.close()
418       
419    def write_python_wrapper(self):
420        """ Writes the python file to create the python extension class
421            The file is written in ../[PYTHONCLASS].py
422        """
423       
424        file = open("../sans/models/"+self.pythonClass+'.py', 'w')
425        template = open("modelTemplate.txt", 'r')
426       
427        tmp_buf = template.read()
428        tmp_lines = tmp_buf.split('\n')
429       
430        for tmp_line in tmp_lines:
431           
432            # Catch class name
433            newline = self.replaceToken(tmp_line, 
434                                        "[CPYTHONCLASS]", 'C'+self.pythonClass)
435           
436            # Catch class name
437            newline = self.replaceToken(newline, 
438                                        "[PYTHONCLASS]", self.pythonClass)
439           
440            # Include file
441            newline = self.replaceToken(newline, 
442                                        "[INCLUDE_FILE]", self.file)   
443                   
444            # Include file
445            newline = self.replaceToken(newline, 
446                                        "[DEFAULT_LIST]", self.default_list)
447            # model description
448            newline = self.replaceToken(newline, 
449                                        "[DESCRIPTION]", self.description)
450            # Parameter details
451            newline = self.replaceToken(newline, 
452                                        "[PAR_DETAILS]", self.details)
453           
454            # fixed list  details
455            newline = self.replaceToken(newline, 
456                                        "[FIXED]",str(self.fixed))
457            # non-fittable list  details
458            newline = self.replaceToken(newline, 
459                                        "[NON_FITTABLE_PARAMS]",str(self.non_fittable))
460            ## parameters with orientation
461       
462            newline = self.replaceToken(newline, 
463                               "[ORIENTATION_PARAMS]",str(self.orientation_params))
464           
465            # Write new line to the wrapper .c file
466            file.write(newline+'\n')
467               
468        file.close()
469       
470       
471    def replaceToken(self, line, key, value): #pylint: disable-msg=R0201
472        """ Replace a token in the template file
473            @param line: line of text to inspect
474            @param key: token to look for
475            @param value: string value to replace the token with
476            @return: new string value
477        """
478        lenkey = len(key)
479        newline = line
480       
481        while newline.count(key)>0:
482            index = newline.index(key)
483            newline = newline[:index]+value+newline[index+lenkey:]
484       
485        return newline
486       
487       
488# main
489if __name__ == '__main__':
490    if len(sys.argv)>1:
491        print "Will look for file %s" % sys.argv[1]
492    #app = WrapperGenerator('../c_extensions/elliptical_cylinder.h')
493        app = WrapperGenerator(sys.argv[1])
494    else:
495        app = WrapperGenerator("test.h")
496    app.read()
497    app.write_c_wrapper()
498    app.write_python_wrapper()
499    print app
500   
501# End of file       
Note: See TracBrowser for help on using the repository browser.