source: sasview/sansmodels/src/python_wrapper/WrapperGenerator.py @ 7ffa8196

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 7ffa8196 was d62f422, checked in by Mathieu Doucet <doucetm@…>, 13 years ago

updated WrapperGenerator? to deal with both pre- and post-refactor models.

  • Property mode set to 100644
File size: 19.6 KB
Line 
1#!/usr/bin/env python
2""" WrapperGenerator class to generate model code automatically.
3"""
4
5import os, sys,re
6import lineparser
7
8class WrapperGenerator:
9    """ Python wrapper generator for C models
10   
11        The developer must provide a header file describing
12        the new model.
13       
14        To provide the name of the Python class to be
15        generated, the .h file must contain the following
16        string in the comments:
17       
18        // [PYTHONCLASS] = my_model
19       
20        where my_model must be replaced by the name of the
21        class that you want to import from sans.models.
22        (example: [PYTHONCLASS] = MyModel
23          will create a class MyModel in sans.models.MyModel.
24          It will also create a class CMyModel in
25          sans_extension.c_models.)
26         
27        Also in comments, each parameter of the params
28        dictionary must be declared with a default value
29        in the following way:
30       
31        // [DEFAULT]=param_name=default_value
32       
33        (example:
34            //  [DEFAULT]=radius=20.0
35        )
36         
37        See cylinder.h for an example.
38       
39       
40        A .c file corresponding to the .h file should also
41        be provided (example: my_model.h, my_model.c).
42   
43        The .h file should define two function definitions. For example,
44        cylinder.h defines the following:
45       
46            /// 1D scattering function
47            double cylinder_analytical_1D(CylinderParameters *pars, double q);
48           
49            /// 2D scattering function
50            double cylinder_analytical_2D(CylinderParameters *pars, double q, double phi);
51           
52        The .c file implements those functions.
53       
54        @author: Mathieu Doucet / UTK
55        @contact: mathieu.doucet@nist.gov
56    """
57   
58    def __init__(self, filename):
59        """ Initialization """
60       
61        ## Name of .h file to generate wrapper from
62        self.file = filename
63       
64        # Info read from file
65       
66        ## Name of python class to write
67        self.pythonClass = None
68        ## Parser in struct section
69        self.inStruct = False
70        self.foundCPP = False
71        self.inParDefs = False
72        ## Name of struct for the c object
73        self.structName = None
74        ## Dictionary of parameters
75        self.params = {}
76        ## ModelCalculation module flag
77        self.modelCalcFlag = False
78        ## List of default parameters (text)
79        self.default_list = ""
80        ## Dictionary of units
81        self.details = ""
82        ## List of dispersed parameters
83        self.disp_params = []
84        #model description
85        self.description=''
86        # paramaters for fittable
87        self.fixed= []
88        # paramaters for non-fittable
89        self.non_fittable= []
90        ## parameters with orientation
91        self.orientation_params =[]
92       
93       
94    def __repr__(self):
95        """ Simple output for printing """
96       
97        rep  = "\n Python class: %s\n\n" % self.pythonClass
98        rep += "  struc name: %s\n\n" % self.structName
99        rep += "  params:     %s\n\n" % self.params
100        rep += "  description:    %s\n\n" % self.description
101        rep += "  Fittable parameters:     %s\n\n"% self.fixed
102        rep += "  Non-Fittable parameters:     %s\n\n"% self.non_fittable
103        rep += "  Orientation parameters:  %s\n\n"% self.orientation_params
104        return rep
105       
106    def read(self):
107        """ Reads in the .h file to catch parameters of the wrapper """
108       
109        # Check if the file is there
110        if not os.path.isfile(self.file):
111            raise ValueError, "File %s is not a regular file" % self.file
112       
113        # Read file
114        f = open(self.file,'r')
115        buf = f.read()
116       
117        self.default_list = "List of default parameters:\n"
118        #lines = string.split(buf,'\n')
119        lines = buf.split('\n')
120        self.details  = "## Parameter details [units, min, max]\n"
121        self.details += "        self.details = {}\n"
122       
123        #open item in this case Fixed
124        text='text'
125        key2="<%s>"%text.lower()
126        # close an item in this case fixed
127        text='TexT'
128        key3="</%s>"%text.lower()
129       
130        ## Catch fixed parameters
131        key = "[FIXED]"
132        try:
133            self.fixed= lineparser.readhelper(lines,key, key2,key3, file= self.file)
134        except:
135           raise   
136        ## Catch non-fittable parameters parameters
137        key = "[NON_FITTABLE_PARAMS]"
138        try:
139            self.non_fittable= lineparser.readhelper(lines,key, key2,key3, file= self.file)
140        except:
141           raise   
142
143        ## Catch parameters with orientation
144        key = "[ORIENTATION_PARAMS]"   
145        try:
146            self.orientation_params = lineparser.readhelper( lines,key, 
147                                                    key2,key3, file= self.file)
148        except:
149           raise 
150        ## Catch Description
151        key = "[DESCRIPTION]"
152       
153        find_description = False
154        temp=""
155        for line in lines:
156            if line.count(key)>0 :
157               
158                try:
159                    find_description= True
160                    index = line.index(key)
161                    toks = line[index:].split("=",1 )
162                    temp=toks[1].lstrip().rstrip()
163                    text='text'
164                    key2="<%s>"%text.lower()
165                    if re.match(key2,temp)!=None:
166   
167                        toks2=temp.split(key2,1)
168                        self.description=toks2[1]
169                        text='text'
170                        key2="</%s>"%text.lower()
171                        if re.search(key2,toks2[1])!=None:
172                            temp=toks2[1].split(key2,1)
173                            self.description=temp[0]
174                            break
175                     
176                    else:
177                        self.description=temp
178                        break
179                except:
180                     raise ValueError, "Could not parse file %s" % self.file
181            elif find_description:
182                text='text'
183                key2="</%s>"%text.lower()
184                if re.search(key2,line)!=None:
185                    tok=line.split(key2,1)
186                    temp=tok[0].split("//",1)
187                    self.description+=tok[1].lstrip().rstrip()
188                    break
189                else:
190                    if re.search("//",line)!=None:
191                        temp=line.split("//",1)
192                        self.description+='\n\t\t'+temp[1].lstrip().rstrip()
193                       
194                    else:
195                        self.description+='\n\t\t'+line.lstrip().rstrip()
196                   
197               
198               
199        for line in lines:
200           
201            # Catch class name
202            key = "[PYTHONCLASS]"
203            if line.count(key)>0:
204                try:
205                    index = line.index(key)
206                    #toks = string.split( line[index:], "=" )
207                    toks = line[index:].split("=" )
208                    self.pythonClass = toks[1].lstrip().rstrip()
209                except:
210                    raise ValueError, "Could not parse file %s" % self.file
211               
212            # Catch struct name
213            # C++ class definition
214            if line.count("class")>0:
215                # We are entering a class definition
216                self.inParDefs = True
217                self.foundCPP = True
218               
219            # Old-Style C struct definition
220            if line.count("typedef struct")>0:
221                # We are entering a struct block
222                self.inParDefs = True
223                self.inStruct = True
224           
225            if self.inParDefs and line.count("}")>0:
226                # We are exiting a struct block
227                self.inParDefs = False
228               
229                if self.inStruct:
230                    self.inStruct = False
231                    # Catch the name of the struct
232                    index = line.index("}")
233                    toks = line[index+1:].split(";")
234                    # Catch pointer definition
235                    toks2 = toks[0].split(',')
236                    self.structName = toks2[0].lstrip().rstrip()
237           
238            # Catch struct content
239            key = "[DEFAULT]"
240            if self.inParDefs and line.count(key)>0:
241                # Found a new parameter
242                try:
243                    index = line.index(key)
244                    toks = line[index:].split("=")
245                    toks2 = toks[2].split()
246                    val = float(toks2[0])
247                    self.params[toks[1]] = val
248                    #self.pythonClass = toks[1].lstrip().rstrip()
249                    units = ""
250                    if len(toks2) >= 2:
251                        units = toks2[1]
252                    self.default_list += "         %-15s = %s %s\n" % \
253                        (toks[1], val, units)
254                   
255                    # Check for min and max
256                    min = "None"
257                    max = "None"
258                    if len(toks2) == 4:
259                        min = toks2[2]
260                        max = toks2[3]
261                   
262                    self.details += "        self.details['%s'] = ['%s', %s, %s]\n" % \
263                        (toks[1].lstrip().rstrip(), units.lstrip().rstrip(), min, max)
264                except:
265                    raise ValueError, "Could not parse input file %s \n  %s" % \
266                        (self.file, sys.exc_value)
267               
268               
269            # Catch need for numerical calculations
270            key = "CalcParameters calcPars"
271            if line.count(key)>0:
272                self.modelCalcFlag = True
273               
274            # Catch list of dispersed parameters
275            key = "[DISP_PARAMS]"
276            if line.count(key)>0:
277                try:
278                    index = line.index(key)
279                    toks = line[index:].split("=")
280                    list_str = toks[1].lstrip().rstrip()
281                    self.disp_params = list_str.split(',')
282                except:
283                    raise ValueError, "Could not parse file %s" % self.file
284               
285       
286               
287    def write_c_wrapper(self):
288        """ Writes the C file to create the python extension class
289            The file is written in C[PYTHONCLASS].c
290        """
291       
292        file = open("C"+self.pythonClass+'.cpp', 'w')
293        template = open("classTemplate.txt", 'r')
294       
295        tmp_buf = template.read()
296        #tmp_lines = string.split(tmp_buf,'\n')
297        tmp_lines = tmp_buf.split('\n')
298       
299        for tmp_line in tmp_lines:
300           
301            # Catch class name
302            newline = self.replaceToken(tmp_line, 
303                                        "[PYTHONCLASS]", 'C'+self.pythonClass)
304            #Catch model description
305            #newline = self.replaceToken(tmp_line,
306            #                            "[DESCRIPTION]", self.description)
307            # Catch C model name
308            newline = self.replaceToken(newline, 
309                                        "[CMODEL]", self.pythonClass)
310           
311            # Catch class name
312            newline = self.replaceToken(newline, 
313                                        "[MODELSTRUCT]", self.structName)
314           
315            # Dictionary initialization
316            param_str = "// Initialize parameter dictionary\n"           
317            for par in self.params:
318                param_str += "        PyDict_SetItemString(self->params,\"%s\",Py_BuildValue(\"d\",%10.12f));\n" % \
319                    (par, self.params[par])
320
321            param_str += "        // Initialize dispersion / averaging parameter dict\n"
322            param_str += "        DispersionVisitor* visitor = new DispersionVisitor();\n"
323            param_str += "        PyObject * disp_dict;\n"
324            for par in self.disp_params:
325                par = par.strip()
326                param_str += "        disp_dict = PyDict_New();\n"
327                param_str += "        self->model->%s.dispersion->accept_as_source(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
328                param_str += "        PyDict_SetItemString(self->dispersion, \"%s\", disp_dict);\n" % par
329               
330            # Initialize dispersion object dictionnary
331            param_str += "\n"
332           
333               
334            newline = self.replaceToken(newline,
335                                        "[INITDICTIONARY]", param_str)
336           
337            # Read dictionary
338            param_str = "    // Reader parameter dictionary\n"
339            for par in self.params:
340                param_str += "    self->model->%s = PyFloat_AsDouble( PyDict_GetItemString(self->params, \"%s\") );\n" % \
341                    (par, par)
342                   
343            param_str += "    // Read in dispersion parameters\n"
344            param_str += "    PyObject* disp_dict;\n"
345            param_str += "    DispersionVisitor* visitor = new DispersionVisitor();\n"
346            for par in self.disp_params:
347                par = par.strip()
348                param_str += "    disp_dict = PyDict_GetItemString(self->dispersion, \"%s\");\n" % par
349                param_str += "    self->model->%s.dispersion->accept_as_destination(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
350               
351            newline = self.replaceToken(newline, "[READDICTIONARY]", param_str)
352               
353            # Name of .c file
354            #toks = string.split(self.file,'.')
355            basename = os.path.basename(self.file)
356            toks = basename.split('.')
357            newline = self.replaceToken(newline, "[C_FILENAME]", toks[0])
358           
359            # Include file
360            basename = os.path.basename(self.file)
361            newline = self.replaceToken(newline, 
362                                        "[INCLUDE_FILE]", self.file) 
363            if self.foundCPP:
364                newline = self.replaceToken(newline, 
365                                            "[C_INCLUDE_FILE]", "") 
366                newline = self.replaceToken(newline, 
367                                            "[CPP_INCLUDE_FILE]", "#include \"%s\"" % basename) 
368            else: 
369                newline = self.replaceToken(newline, 
370                                            "[C_INCLUDE_FILE]", "#include \"%s\"" % basename)   
371                newline = self.replaceToken(newline, 
372                                            "[CPP_INCLUDE_FILE]", "#include \"models.hh\"") 
373               
374            # Numerical calcs dealloc
375            dealloc_str = "\n"
376            if self.modelCalcFlag:
377                dealloc_str = "    modelcalculations_dealloc(&(self->model_pars.calcPars));\n"
378            newline = self.replaceToken(newline, 
379                                        "[NUMERICAL_DEALLOC]", dealloc_str)     
380               
381            # Numerical calcs init
382            init_str = "\n"
383            if self.modelCalcFlag:
384                init_str = "        modelcalculations_init(&(self->model_pars.calcPars));\n"
385            newline = self.replaceToken(newline, 
386                                        "[NUMERICAL_INIT]", init_str)     
387               
388            # Numerical calcs reset
389            reset_str = "\n"
390            if self.modelCalcFlag:
391                reset_str = "modelcalculations_reset(&(self->model_pars.calcPars));\n"
392            newline = self.replaceToken(newline, 
393                                        "[NUMERICAL_RESET]", reset_str)     
394               
395            # Setting dispsertion weights
396            set_weights = "    // Ugliness necessary to go from python to C\n"
397            set_weights = "    // TODO: refactor this\n"
398            for par in self.disp_params:
399                par = par.strip()
400                set_weights += "    if (!strcmp(par_name, \"%s\")) {\n" % par
401                set_weights += "        self->model->%s.dispersion = dispersion;\n" % par
402                set_weights += "    } else"
403            newline = self.replaceToken(newline, 
404                                        "[SET_DISPERSION]", set_weights)     
405           
406            # Write new line to the wrapper .c file
407            file.write(newline+'\n')
408           
409           
410        file.close()
411       
412    def write_python_wrapper(self):
413        """ Writes the python file to create the python extension class
414            The file is written in ../[PYTHONCLASS].py
415        """
416       
417        file = open("../sans/models/"+self.pythonClass+'.py', 'w')
418        template = open("modelTemplate.txt", 'r')
419       
420        tmp_buf = template.read()
421        tmp_lines = tmp_buf.split('\n')
422       
423        for tmp_line in tmp_lines:
424           
425            # Catch class name
426            newline = self.replaceToken(tmp_line, 
427                                        "[CPYTHONCLASS]", 'C'+self.pythonClass)
428           
429            # Catch class name
430            newline = self.replaceToken(newline, 
431                                        "[PYTHONCLASS]", self.pythonClass)
432           
433            # Include file
434            newline = self.replaceToken(newline, 
435                                        "[INCLUDE_FILE]", self.file)   
436                   
437            # Include file
438            newline = self.replaceToken(newline, 
439                                        "[DEFAULT_LIST]", self.default_list)
440            # model description
441            newline = self.replaceToken(newline, 
442                                        "[DESCRIPTION]", self.description)
443            # Parameter details
444            newline = self.replaceToken(newline, 
445                                        "[PAR_DETAILS]", self.details)
446           
447            # fixed list  details
448            newline = self.replaceToken(newline, 
449                                        "[FIXED]",str(self.fixed))
450            # non-fittable list  details
451            newline = self.replaceToken(newline, 
452                                        "[NON_FITTABLE_PARAMS]",str(self.non_fittable))
453            ## parameters with orientation
454       
455            newline = self.replaceToken(newline, 
456                               "[ORIENTATION_PARAMS]",str(self.orientation_params))
457           
458            # Write new line to the wrapper .c file
459            file.write(newline+'\n')
460               
461        file.close()
462       
463       
464    def replaceToken(self, line, key, value): #pylint: disable-msg=R0201
465        """ Replace a token in the template file
466            @param line: line of text to inspect
467            @param key: token to look for
468            @param value: string value to replace the token with
469            @return: new string value
470        """
471        lenkey = len(key)
472        newline = line
473       
474        while newline.count(key)>0:
475            index = newline.index(key)
476            newline = newline[:index]+value+newline[index+lenkey:]
477       
478        return newline
479       
480       
481# main
482if __name__ == '__main__':
483    if len(sys.argv)>1:
484        print "Will look for file %s" % sys.argv[1]
485    #app = WrapperGenerator('../c_extensions/elliptical_cylinder.h')
486        app = WrapperGenerator(sys.argv[1])
487    else:
488        app = WrapperGenerator("test.h")
489    app.read()
490    app.write_c_wrapper()
491    app.write_python_wrapper()
492    print app
493   
494# End of file       
Note: See TracBrowser for help on using the repository browser.