source: sasview/sansmodels/src/sans/models/c_models/WrapperGenerator.py @ 4cc8285c

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 4cc8285c was 35aface, checked in by Jae Cho <jhjcho@…>, 14 years ago

addede new models and attr. non_fittable

  • Property mode set to 100644
File size: 18.7 KB
Line 
1#!/usr/bin/env python
2""" WrapperGenerator class to generate model code automatically.
3"""
4
5import os, sys,re
6import lineparser
7
8class WrapperGenerator:
9    """ Python wrapper generator for C models
10   
11        The developer must provide a header file describing
12        the new model.
13       
14        To provide the name of the Python class to be
15        generated, the .h file must contain the following
16        string in the comments:
17       
18        // [PYTHONCLASS] = my_model
19       
20        where my_model must be replaced by the name of the
21        class that you want to import from sans.models.
22        (example: [PYTHONCLASS] = MyModel
23          will create a class MyModel in sans.models.MyModel.
24          It will also create a class CMyModel in
25          sans_extension.c_models.)
26         
27        Also in comments, each parameter of the params
28        dictionary must be declared with a default value
29        in the following way:
30       
31        // [DEFAULT]=param_name=default_value
32       
33        (example:
34            //  [DEFAULT]=radius=20.0
35        )
36         
37        See cylinder.h for an example.
38       
39       
40        A .c file corresponding to the .h file should also
41        be provided (example: my_model.h, my_model.c).
42   
43        The .h file should define two function definitions. For example,
44        cylinder.h defines the following:
45       
46            /// 1D scattering function
47            double cylinder_analytical_1D(CylinderParameters *pars, double q);
48           
49            /// 2D scattering function
50            double cylinder_analytical_2D(CylinderParameters *pars, double q, double phi);
51           
52        The .c file implements those functions.
53       
54        @author: Mathieu Doucet / UTK
55        @contact: mathieu.doucet@nist.gov
56    """
57   
58    def __init__(self, filename):
59        """ Initialization """
60       
61        ## Name of .h file to generate wrapper from
62        self.file = filename
63       
64        # Info read from file
65       
66        ## Name of python class to write
67        self.pythonClass = None
68        ## Parser in struct section
69        self.inStruct = False
70        ## Name of struct for the c object
71        self.structName = None
72        ## Dictionary of parameters
73        self.params = {}
74        ## ModelCalculation module flag
75        self.modelCalcFlag = False
76        ## List of default parameters (text)
77        self.default_list = ""
78        ## Dictionary of units
79        self.details = ""
80        ## List of dispersed parameters
81        self.disp_params = []
82        #model description
83        self.description=''
84        # paramaters for fittable
85        self.fixed= []
86        # paramaters for non-fittable
87        self.non_fittable= []
88        ## parameters with orientation
89        self.orientation_params =[]
90       
91       
92    def __repr__(self):
93        """ Simple output for printing """
94       
95        rep  = "\n Python class: %s\n\n" % self.pythonClass
96        rep += "  struc name: %s\n\n" % self.structName
97        rep += "  params:     %s\n\n" % self.params
98        rep += "  description:    %s\n\n" % self.description
99        rep += "  Fittable parameters:     %s\n\n"% self.fixed
100        rep += "  Non-Fittable parameters:     %s\n\n"% self.non_fittable
101        rep += "  Orientation parameters:  %s\n\n"% self.orientation_params
102        return rep
103       
104    def read(self):
105        """ Reads in the .h file to catch parameters of the wrapper """
106       
107        # Check if the file is there
108        if not os.path.isfile(self.file):
109            raise ValueError, "File %s is not a regular file" % self.file
110       
111        # Read file
112        f = open(self.file,'r')
113        buf = f.read()
114       
115        self.default_list = "List of default parameters:\n"
116        #lines = string.split(buf,'\n')
117        lines = buf.split('\n')
118        self.details  = "## Parameter details [units, min, max]\n"
119        self.details += "        self.details = {}\n"
120       
121        #open item in this case Fixed
122        text='text'
123        key2="<%s>"%text.lower()
124        # close an item in this case fixed
125        text='TexT'
126        key3="</%s>"%text.lower()
127       
128        ## Catch fixed parameters
129        key = "[FIXED]"
130        try:
131            self.fixed= lineparser.readhelper(lines,key, key2,key3, file= self.file)
132        except:
133           raise   
134        ## Catch non-fittable parameters parameters
135        key = "[NON_FITTABLE_PARAMS]"
136        try:
137            self.non_fittable= lineparser.readhelper(lines,key, key2,key3, file= self.file)
138        except:
139           raise   
140
141        ## Catch parameters with orientation
142        key = "[ORIENTATION_PARAMS]"   
143        try:
144            self.orientation_params = lineparser.readhelper( lines,key, 
145                                                    key2,key3, file= self.file)
146        except:
147           raise 
148        ## Catch Description
149        key = "[DESCRIPTION]"
150       
151        find_description = False
152        temp=""
153        for line in lines:
154            if line.count(key)>0 :
155               
156                try:
157                    find_description= True
158                    index = line.index(key)
159                    toks = line[index:].split("=",1 )
160                    temp=toks[1].lstrip().rstrip()
161                    text='text'
162                    key2="<%s>"%text.lower()
163                    if re.match(key2,temp)!=None:
164   
165                        toks2=temp.split(key2,1)
166                        self.description=toks2[1]
167                        text='text'
168                        key2="</%s>"%text.lower()
169                        if re.search(key2,toks2[1])!=None:
170                            temp=toks2[1].split(key2,1)
171                            self.description=temp[0]
172                            break
173                     
174                    else:
175                        self.description=temp
176                        break
177                except:
178                     raise ValueError, "Could not parse file %s" % self.file
179            elif find_description:
180                text='text'
181                key2="</%s>"%text.lower()
182                if re.search(key2,line)!=None:
183                    tok=line.split(key2,1)
184                    temp=tok[0].split("//",1)
185                    self.description+=tok[1].lstrip().rstrip()
186                    break
187                else:
188                    if re.search("//",line)!=None:
189                        temp=line.split("//",1)
190                        self.description+='\n\t\t'+temp[1].lstrip().rstrip()
191                       
192                    else:
193                        self.description+='\n\t\t'+line.lstrip().rstrip()
194                   
195               
196               
197        for line in lines:
198           
199            # Catch class name
200            key = "[PYTHONCLASS]"
201            if line.count(key)>0:
202                try:
203                    index = line.index(key)
204                    #toks = string.split( line[index:], "=" )
205                    toks = line[index:].split("=" )
206                    self.pythonClass = toks[1].lstrip().rstrip()
207                except:
208                    raise ValueError, "Could not parse file %s" % self.file
209               
210            # Catch struct name
211            if line.count("typedef struct")>0:
212                # We are entering a struct block
213                self.inStruct = True
214           
215            if self.inStruct and line.count("}")>0:
216                # We are exiting a struct block
217                self.inStruct = False
218   
219                # Catch the name of the struct
220                index = line.index("}")
221                #toks = string.split(line[index+1:],";")
222                toks = line[index+1:].split(";")
223                # Catch pointer definition
224                #toks2 = string.split(toks[0],',')
225                toks2 = toks[0].split(',')
226                self.structName = toks2[0].lstrip().rstrip()
227           
228               
229           
230            # Catch struct content
231            key = "[DEFAULT]"
232            if self.inStruct and line.count(key)>0:
233                # Found a new parameter
234                try:
235                    index = line.index(key)
236                    toks = line[index:].split("=")
237                    toks2 = toks[2].split()
238                    val = float(toks2[0])
239                    self.params[toks[1]] = val
240                    #self.pythonClass = toks[1].lstrip().rstrip()
241                    units = ""
242                    if len(toks2) >= 2:
243                        units = toks2[1]
244                    self.default_list += "         %-15s = %s %s\n" % \
245                        (toks[1], val, units)
246                   
247                    # Check for min and max
248                    min = "None"
249                    max = "None"
250                    if len(toks2) == 4:
251                        min = toks2[2]
252                        max = toks2[3]
253                   
254                    self.details += "        self.details['%s'] = ['%s', %s, %s]\n" % \
255                        (toks[1].lstrip().rstrip(), units.lstrip().rstrip(), min, max)
256                except:
257                    raise ValueError, "Could not parse input file %s \n  %s" % \
258                        (self.file, sys.exc_value)
259               
260               
261            # Catch need for numerical calculations
262            key = "CalcParameters calcPars"
263            if line.count(key)>0:
264                self.modelCalcFlag = True
265               
266            # Catch list of dispersed parameters
267            key = "[DISP_PARAMS]"
268            if line.count(key)>0:
269                try:
270                    index = line.index(key)
271                    toks = line[index:].split("=")
272                    list_str = toks[1].lstrip().rstrip()
273                    self.disp_params = list_str.split(',')
274                except:
275                    raise ValueError, "Could not parse file %s" % self.file
276               
277       
278               
279    def write_c_wrapper(self):
280        """ Writes the C file to create the python extension class
281            The file is written in C[PYTHONCLASS].c
282        """
283       
284        file = open("C"+self.pythonClass+'.cpp', 'w')
285        template = open("classTemplate.txt", 'r')
286       
287        tmp_buf = template.read()
288        #tmp_lines = string.split(tmp_buf,'\n')
289        tmp_lines = tmp_buf.split('\n')
290       
291        for tmp_line in tmp_lines:
292           
293            # Catch class name
294            newline = self.replaceToken(tmp_line, 
295                                        "[PYTHONCLASS]", 'C'+self.pythonClass)
296            #Catch model description
297            #newline = self.replaceToken(tmp_line,
298            #                            "[DESCRIPTION]", self.description)
299            # Catch C model name
300            newline = self.replaceToken(newline, 
301                                        "[CMODEL]", self.pythonClass)
302           
303            # Catch class name
304            newline = self.replaceToken(newline, 
305                                        "[MODELSTRUCT]", self.structName)
306           
307            # Dictionary initialization
308            param_str = "// Initialize parameter dictionary\n"           
309            for par in self.params:
310                param_str += "        PyDict_SetItemString(self->params,\"%s\",Py_BuildValue(\"d\",%10.12f));\n" % \
311                    (par, self.params[par])
312
313            param_str += "        // Initialize dispersion / averaging parameter dict\n"
314            param_str += "        DispersionVisitor* visitor = new DispersionVisitor();\n"
315            param_str += "        PyObject * disp_dict;\n"
316            for par in self.disp_params:
317                par = par.strip()
318                param_str += "        disp_dict = PyDict_New();\n"
319                param_str += "        self->model->%s.dispersion->accept_as_source(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
320                param_str += "        PyDict_SetItemString(self->dispersion, \"%s\", disp_dict);\n" % par
321               
322            # Initialize dispersion object dictionnary
323            param_str += "\n"
324           
325               
326            newline = self.replaceToken(newline,
327                                        "[INITDICTIONARY]", param_str)
328           
329            # Read dictionary
330            param_str = "    // Reader parameter dictionary\n"
331            for par in self.params:
332                param_str += "    self->model->%s = PyFloat_AsDouble( PyDict_GetItemString(self->params, \"%s\") );\n" % \
333                    (par, par)
334                   
335            param_str += "    // Read in dispersion parameters\n"
336            param_str += "    PyObject* disp_dict;\n"
337            param_str += "    DispersionVisitor* visitor = new DispersionVisitor();\n"
338            for par in self.disp_params:
339                par = par.strip()
340                param_str += "    disp_dict = PyDict_GetItemString(self->dispersion, \"%s\");\n" % par
341                param_str += "    self->model->%s.dispersion->accept_as_destination(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
342               
343            newline = self.replaceToken(newline, "[READDICTIONARY]", param_str)
344               
345            # Name of .c file
346            #toks = string.split(self.file,'.')
347            basename = os.path.basename(self.file)
348            toks = basename.split('.')
349            newline = self.replaceToken(newline, "[C_FILENAME]", toks[0])
350           
351            # Include file
352            basename = os.path.basename(self.file)
353            newline = self.replaceToken(newline, 
354                                        "[INCLUDE_FILE]", basename)           
355               
356            # Numerical calcs dealloc
357            dealloc_str = "\n"
358            if self.modelCalcFlag:
359                dealloc_str = "    modelcalculations_dealloc(&(self->model_pars.calcPars));\n"
360            newline = self.replaceToken(newline, 
361                                        "[NUMERICAL_DEALLOC]", dealloc_str)     
362               
363            # Numerical calcs init
364            init_str = "\n"
365            if self.modelCalcFlag:
366                init_str = "        modelcalculations_init(&(self->model_pars.calcPars));\n"
367            newline = self.replaceToken(newline, 
368                                        "[NUMERICAL_INIT]", init_str)     
369               
370            # Numerical calcs reset
371            reset_str = "\n"
372            if self.modelCalcFlag:
373                reset_str = "modelcalculations_reset(&(self->model_pars.calcPars));\n"
374            newline = self.replaceToken(newline, 
375                                        "[NUMERICAL_RESET]", reset_str)     
376               
377            # Setting dispsertion weights
378            set_weights = "    // Ugliness necessary to go from python to C\n"
379            set_weights = "    // TODO: refactor this\n"
380            for par in self.disp_params:
381                par = par.strip()
382                set_weights += "    if (!strcmp(par_name, \"%s\")) {\n" % par
383                set_weights += "        self->model->%s.dispersion = dispersion;\n" % par
384                set_weights += "    } else"
385            newline = self.replaceToken(newline, 
386                                        "[SET_DISPERSION]", set_weights)     
387           
388            # Write new line to the wrapper .c file
389            file.write(newline+'\n')
390           
391           
392        file.close()
393       
394    def write_python_wrapper(self):
395        """ Writes the python file to create the python extension class
396            The file is written in ../[PYTHONCLASS].py
397        """
398       
399        file = open("../"+self.pythonClass+'.py', 'w')
400        template = open("modelTemplate.txt", 'r')
401       
402        tmp_buf = template.read()
403        tmp_lines = tmp_buf.split('\n')
404       
405        for tmp_line in tmp_lines:
406           
407            # Catch class name
408            newline = self.replaceToken(tmp_line, 
409                                        "[CPYTHONCLASS]", 'C'+self.pythonClass)
410           
411            # Catch class name
412            newline = self.replaceToken(newline, 
413                                        "[PYTHONCLASS]", self.pythonClass)
414           
415            # Include file
416            newline = self.replaceToken(newline, 
417                                        "[INCLUDE_FILE]", self.file)   
418                   
419            # Include file
420            newline = self.replaceToken(newline, 
421                                        "[DEFAULT_LIST]", self.default_list)
422            # model description
423            newline = self.replaceToken(newline, 
424                                        "[DESCRIPTION]", self.description)
425            # Parameter details
426            newline = self.replaceToken(newline, 
427                                        "[PAR_DETAILS]", self.details)
428           
429            # fixed list  details
430            newline = self.replaceToken(newline, 
431                                        "[FIXED]",str(self.fixed))
432            # non-fittable list  details
433            newline = self.replaceToken(newline, 
434                                        "[NON_FITTABLE_PARAMS]",str(self.non_fittable))
435            ## parameters with orientation
436       
437            newline = self.replaceToken(newline, 
438                               "[ORIENTATION_PARAMS]",str(self.orientation_params))
439           
440            # Write new line to the wrapper .c file
441            file.write(newline+'\n')
442               
443        file.close()
444       
445       
446    def replaceToken(self, line, key, value): #pylint: disable-msg=R0201
447        """ Replace a token in the template file
448            @param line: line of text to inspect
449            @param key: token to look for
450            @param value: string value to replace the token with
451            @return: new string value
452        """
453        lenkey = len(key)
454        newline = line
455       
456        while newline.count(key)>0:
457            index = newline.index(key)
458            newline = newline[:index]+value+newline[index+lenkey:]
459       
460        return newline
461       
462       
463# main
464if __name__ == '__main__':
465    if len(sys.argv)>1:
466        print "Will look for file %s" % sys.argv[1]
467    #app = WrapperGenerator('../c_extensions/elliptical_cylinder.h')
468        app = WrapperGenerator(sys.argv[1])
469    else:
470        app = WrapperGenerator("test.h")
471    app.read()
472    app.write_c_wrapper()
473    app.write_python_wrapper()
474    print app
475   
476# End of file       
Note: See TracBrowser for help on using the repository browser.