source: sasview/sansmodels/src/sans/models/c_models/WrapperGenerator.py @ c9636f7

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since c9636f7 was da3dae3, checked in by Gervaise Alina <gervyh@…>, 16 years ago

wrappergenerator modified

  • Property mode set to 100644
File size: 24.8 KB
Line 
1#!/usr/bin/env python
2""" WrapperGenerator class to generate model code automatically.
3"""
4
5import os, sys,re
6def split_list(separator, mylist, n=0):
7    """
8        @return a list of string without white space of separator
9        @param separator: the string to remove
10    """
11    list=[]
12    for item in mylist:
13        if re.search( separator,item)!=None:
14            if n >0:
15                word =re.split(separator,item,int(n))
16            else:
17                word =re.split( separator,item)
18            for new_item in word: 
19                if new_item.lstrip().rstrip() !='':
20                    list.append(new_item.lstrip().rstrip())
21    return list
22def split_text(separator, string1, n=0):
23    """
24        @return a list of string without white space of separator
25        @param separator: the string to remove
26    """
27    list=[]
28    if re.search( separator,string1)!=None:
29        if n >0:
30            word =re.split(separator,string1,int(n))
31        else:
32            word =re.split(separator,string1)
33        for item in word: 
34            if item.lstrip().rstrip() !='':
35                list.append(item.lstrip().rstrip())
36    return list
37def look_for_tag( string1,begin, end=None ):
38    """
39        @note: this method  remove the begin and end tags given by the user
40        from the string .
41        @param begin: the initial tag
42        @param end: the final tag
43        @param string: the string to check
44        @return: begin_flag==True if begin was found,
45         end_flag==if end was found else return false, false
46         
47    """
48    begin_flag= False
49    end_flag= False
50    if  re.search( begin,string1)!=None:
51        begin_flag= True
52    if end !=None:
53        if  re.search(end,string1)!=None:
54            end_flag= True
55    return begin_flag, end_flag
56
57class WrapperGenerator:
58    """ Python wrapper generator for C models
59   
60        The developer must provide a header file describing
61        the new model.
62       
63        To provide the name of the Python class to be
64        generated, the .h file must contain the following
65        string in the comments:
66       
67        // [PYTHONCLASS] = my_model
68       
69        where my_model must be replaced by the name of the
70        class that you want to import from sans.models.
71        (example: [PYTHONCLASS] = MyModel
72          will create a class MyModel in sans.models.MyModel.
73          It will also create a class CMyModel in
74          sans_extension.c_models.)
75         
76        Also in comments, each parameter of the params
77        dictionary must be declared with a default value
78        in the following way:
79       
80        // [DEFAULT]=param_name=default_value
81       
82        (example:
83            //  [DEFAULT]=radius=20.0
84        )
85         
86        See cylinder.h for an example.
87       
88       
89        A .c file corresponding to the .h file should also
90        be provided (example: my_model.h, my_model.c).
91   
92        The .h file should define two function definitions. For example,
93        cylinder.h defines the following:
94       
95            /// 1D scattering function
96            double cylinder_analytical_1D(CylinderParameters *pars, double q);
97           
98            /// 2D scattering function
99            double cylinder_analytical_2D(CylinderParameters *pars, double q, double phi);
100           
101        The .c file implements those functions.
102       
103        @author: Mathieu Doucet / UTK
104        @contact: mathieu.doucet@nist.gov
105    """
106   
107    def __init__(self, filename):
108        """ Initialization """
109       
110        ## Name of .h file to generate wrapper from
111        self.file = filename
112       
113        # Info read from file
114       
115        ## Name of python class to write
116        self.pythonClass = None
117        ## Parser in struct section
118        self.inStruct = False
119        ## Name of struct for the c object
120        self.structName = None
121        ## Dictionary of parameters
122        self.params = {}
123        ## ModelCalculation module flag
124        self.modelCalcFlag = False
125        ## List of default parameters (text)
126        self.default_list = ""
127        ## Dictionary of units
128        self.details = ""
129        ## List of dispersed parameters
130        self.disp_params = []
131        #model description
132        self.description=''
133        # paramaters for fittable
134        self.fixed= []
135       
136    def __repr__(self):
137        """ Simple output for printing """
138       
139        rep  = "Python class: %s\n" % self.pythonClass
140        rep += "  struc name: %s\n" % self.structName
141        rep += "  params:     %s\n" % self.params
142        rep += "  description:     %s\n" % self.description
143        rep += "  fittable paramaters list %s\n"% self.fixed
144        return rep
145       
146    def read(self):
147        """ Reads in the .h file to catch parameters of the wrapper """
148       
149        # Check if the file is there
150        if not os.path.isfile(self.file):
151            raise ValueError, "File %s is not a regular file" % self.file
152       
153        # Read file
154        f = open(self.file,'r')
155        buf = f.read()
156       
157        self.default_list = "List of default parameters:\n"
158        #lines = string.split(buf,'\n')
159        lines = buf.split('\n')
160        self.details  = "## Parameter details [units, min, max]\n"
161        self.details += "        self.details = {}\n"
162         # Catch Description
163        key = "[FIXED]"
164        #open item in this case Fixed
165        text='text'
166        key2="<%s>"%text.lower()
167        # close an item in this case fixed
168        text='TexT'
169        key3="</%s>"%text.lower()
170        temp=""
171        # flag to found key
172        find_fixed= False
173        find_key2=False
174        find_key3=False
175       
176        for line in lines:
177            if line.count(key)>0 :#[FIXED]= .....
178                try:
179                    find_fixed= True
180                    index = line.index(key)
181                    toks  = line[index:].split("=",1 )
182                    temp  = toks[1].lstrip().rstrip()
183                    find_key2, find_key3=look_for_tag( string1=temp,begin=key2, end=key3 )
184                    if find_key2 and find_key3:
185                        #print"Found both:find_key2, find_key3", find_key2, find_key3
186                        temp1=[]
187                        temp2=[]
188                        temp3=[]
189                        temp4=[]
190                        temp1=split_text(separator=key2, string1=temp)
191                        temp2=split_list(separator=key3, mylist=temp1)
192                        temp3=split_list(separator=';', mylist=temp2)
193                        temp4=split_list(separator=',', mylist=temp3)
194                        self.fixed= temp3 + temp4
195                        break
196                       
197                    elif find_key2 and not find_key3:
198                        #print"Find  find_key2 not find_key3", find_key2, find_key3
199                        temp1=[]
200                        temp2=[]
201                        temp3=[]
202                        temp4=[]
203                        temp1=split_text(separator=key2, string1=temp)
204                        temp3=split_list(separator=';', mylist=temp1)
205                        temp4=split_list(separator=',', mylist=temp3)
206                        #print "found only key2 , temp4", temp3,temp4
207                        if len(temp3+ temp4)==0:# [FIXED]=  only one param
208                            self.fixed+= temp1
209                        self.fixed+=temp3+temp4 #
210                    elif not (find_key2 and find_key3) :
211                        #print"Find nothing find_key2 not find_key3", find_key2, find_key3
212                        temp3=[]
213                        temp4=[]
214                        if look_for_tag( string1=temp,begin=";")[0]:# split ";" first
215                            temp3=split_text(separator=';',string1=temp)
216                            temp4=split_list(separator=',', mylist=temp3)
217                        else:
218                            temp3=split_text(separator=',',string1=temp)# slip "," first
219                            temp4=split_list(separator=';', mylist=temp3)
220                        if len(temp3+ temp4)==0:# [FIXED]=  only one param
221                            self.fixed= [temp]
222                        self.fixed+=temp3+temp4 #
223                        break
224                except:
225                    raise ValueError, "Could not parse file %s" % self.file
226           
227            elif find_fixed:
228                #looking only for key3
229                if not find_key2:
230                    raise ValueError, "Could not parse file %s" % self.file
231                find_key3=look_for_tag( string1=line,begin=key3)[0]# split "</text>" first
232                #print "find_key3",find_key3,line,key3
233                if find_key3:
234                    temp1=[]
235                    temp2=[]
236                    temp3=[]
237                    temp4=[]
238                    temp5=[]
239                   
240                    temp1=split_text(separator=key3, string1=line)
241                    temp2=split_list(separator='//',mylist=temp1)
242                    temp5=split_list(separator="\*",mylist=temp1)
243                   
244                    if len(temp5)>0:
245                        temp3=split_list(separator=';',mylist=temp5)
246                        temp4=split_list(separator=',', mylist=temp5)
247                    elif len(temp2)>0:
248                        temp3=split_list(separator=';',mylist=temp2)
249                        temp4=split_list(separator=',', mylist=temp2)
250                    else:
251                        temp3=split_list(separator=';',mylist=temp1)
252                        temp4=split_list(separator=',', mylist=temp1)
253                   
254                    if len(temp3+ temp4)==0:# [FIXED]=  only one param
255                        self.fixed+= temp1
256                    self.fixed+=temp3+temp4 #   
257                    break 
258                else:
259                    temp2=split_text(separator='//',string1=line)
260                    temp5=split_text(separator="\*",string1=line)
261                    if len(temp5)>0:
262                        temp3=split_list(separator=';',mylist=temp5)
263                        temp4=split_list(separator=',', mylist=temp5)
264                    elif len(temp2)>0:
265                        temp3=split_list(separator=';',mylist=temp2)
266                        temp4=split_list(separator=',', mylist=temp2)
267                    else:
268                        if look_for_tag( string1=line,begin=";")[0]:# split ";" first
269                            temp3=split_text(separator=';',string1=line)
270                            temp4=split_list(separator=',', mylist=temp3)
271                        else:
272                            temp3=split_text(separator=',',string1=line)# slip "," first
273                            temp4=split_list(separator=';', mylist=temp3)
274                    if len(temp3+ temp4)==0:# [FIXED]=  only one param
275                        self.fixed= [line.lstrip().rstrip()]
276                    self.fixed+=temp3+temp4 #
277                   
278               
279                   
280        # Catch Description
281        key = "[DESCRIPTION]"
282        find_description= 0
283        temp=""
284        for line in lines:
285            if line.count(key)>0 :
286               
287                try:
288                    find_description= 1
289                    index = line.index(key)
290                    toks = line[index:].split("=",1 )
291                    temp=toks[1].lstrip().rstrip()
292                    text='text'
293                    key2="<%s>"%text.lower()
294                    if re.match(key2,temp)!=None:
295                        #index2 = line.index(key2)
296                        #temp = temp[index2:]
297                        toks2=temp.split(key2,1)
298                        self.description=toks2[1]
299                        text='text'
300                        key2="</%s>"%text.lower()
301                        if re.search(key2,toks2[1])!=None:
302                            temp=toks2[1].split(key2,1)
303                            self.description=temp[0]
304                            break
305                        #print self.description
306                    else:
307                        self.description=temp
308                        break
309                except:
310                     raise
311                     raise ValueError, "Could not parse file %s" % self.file
312            elif find_description==1:
313                text='text'
314                key2="</%s>"%text.lower()
315                #print "second line",line,key2,re.search(key2,line)
316                if re.search(key2,line)!=None:
317                    tok=line.split(key2,1)
318                    temp=tok[0].split("//",1)
319                    self.description+=tok[1].lstrip().rstrip()
320                    break
321                else:
322                    #if re.search("*",line)!=None:
323                    #    temp=line.split("*",1)
324                    #    self.description+='\n'+temp[1].lstrip().rstrip()
325                    if re.search("//",line)!=None:
326                        temp=line.split("//",1)
327                        self.description+='\n\t\t'+temp[1].lstrip().rstrip()
328                       
329                    else:
330                        self.description+='\n\t\t'+line.lstrip().rstrip()
331                   
332               
333               
334        for line in lines:
335           
336            # Catch class name
337            key = "[PYTHONCLASS]"
338            if line.count(key)>0:
339                try:
340                    index = line.index(key)
341                    #toks = string.split( line[index:], "=" )
342                    toks = line[index:].split("=" )
343                    self.pythonClass = toks[1].lstrip().rstrip()
344                except:
345                    raise ValueError, "Could not parse file %s" % self.file
346               
347            # Catch struct name
348            if line.count("typedef struct")>0:
349                # We are entering a struct block
350                self.inStruct = True
351           
352            if self.inStruct and line.count("}")>0:
353                # We are exiting a struct block
354                self.inStruct = False
355   
356                # Catch the name of the struct
357                index = line.index("}")
358                #toks = string.split(line[index+1:],";")
359                toks = line[index+1:].split(";")
360                # Catch pointer definition
361                #toks2 = string.split(toks[0],',')
362                toks2 = toks[0].split(',')
363                self.structName = toks2[0].lstrip().rstrip()
364           
365               
366           
367            # Catch struct content
368            key = "[DEFAULT]"
369            if self.inStruct and line.count(key)>0:
370                # Found a new parameter
371                try:
372                    index = line.index(key)
373                    toks = line[index:].split("=")
374                    toks2 = toks[2].split()
375                    val = float(toks2[0])
376                    self.params[toks[1]] = val
377                    #self.pythonClass = toks[1].lstrip().rstrip()
378                    units = ""
379                    if len(toks2) >= 2:
380                        units = toks2[1]
381                    self.default_list += "         %-15s = %s %s\n" % \
382                        (toks[1], val, units)
383                   
384                    # Check for min and max
385                    min = "None"
386                    max = "None"
387                    if len(toks2) == 4:
388                        min = toks2[2]
389                        max = toks2[3]
390                   
391                    self.details += "        self.details['%s'] = ['%s', %s, %s]\n" % \
392                        (toks[1].lstrip().rstrip(), units.lstrip().rstrip(), min, max)
393                except:
394                    raise ValueError, "Could not parse input file %s \n  %s" % \
395                        (self.file, sys.exc_value)
396               
397               
398            # Catch need for numerical calculations
399            key = "CalcParameters calcPars"
400            if line.count(key)>0:
401                self.modelCalcFlag = True
402               
403            # Catch list of dispersed parameters
404            key = "[DISP_PARAMS]"
405            if line.count(key)>0:
406                try:
407                    index = line.index(key)
408                    toks = line[index:].split("=")
409                    list_str = toks[1].lstrip().rstrip()
410                    self.disp_params = list_str.split(',')
411                except:
412                    raise ValueError, "Could not parse file %s" % self.file
413               
414       
415               
416    def write_c_wrapper(self):
417        """ Writes the C file to create the python extension class
418            The file is written in C[PYTHONCLASS].c
419        """
420       
421        file = open("C"+self.pythonClass+'.cpp', 'w')
422        template = open("classTemplate.txt", 'r')
423       
424        tmp_buf = template.read()
425        #tmp_lines = string.split(tmp_buf,'\n')
426        tmp_lines = tmp_buf.split('\n')
427       
428        for tmp_line in tmp_lines:
429           
430            # Catch class name
431            newline = self.replaceToken(tmp_line, 
432                                        "[PYTHONCLASS]", 'C'+self.pythonClass)
433            #Catch model description
434            #newline = self.replaceToken(tmp_line,
435            #                            "[DESCRIPTION]", self.description)
436            # Catch C model name
437            newline = self.replaceToken(newline, 
438                                        "[CMODEL]", self.pythonClass)
439           
440            # Catch class name
441            newline = self.replaceToken(newline, 
442                                        "[MODELSTRUCT]", self.structName)
443           
444            # Dictionary initialization
445            param_str = "// Initialize parameter dictionary\n"           
446            for par in self.params:
447                param_str += "        PyDict_SetItemString(self->params,\"%s\",Py_BuildValue(\"d\",%f));\n" % \
448                    (par, self.params[par])
449
450            param_str += "        // Initialize dispersion / averaging parameter dict\n"
451            param_str += "        DispersionVisitor* visitor = new DispersionVisitor();\n"
452            param_str += "        PyObject * disp_dict;\n"
453            for par in self.disp_params:
454                par = par.strip()
455                param_str += "        disp_dict = PyDict_New();\n"
456                param_str += "        self->model->%s.dispersion->accept_as_source(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
457                param_str += "        PyDict_SetItemString(self->dispersion, \"%s\", disp_dict);\n" % par
458               
459            # Initialize dispersion object dictionnary
460            param_str += "\n"
461           
462               
463            newline = self.replaceToken(newline,
464                                        "[INITDICTIONARY]", param_str)
465           
466            # Read dictionary
467            param_str = "    // Reader parameter dictionary\n"
468            for par in self.params:
469                param_str += "    self->model->%s = PyFloat_AsDouble( PyDict_GetItemString(self->params, \"%s\") );\n" % \
470                    (par, par)
471                   
472            param_str += "    // Read in dispersion parameters\n"
473            param_str += "    PyObject* disp_dict;\n"
474            param_str += "    DispersionVisitor* visitor = new DispersionVisitor();\n"
475            for par in self.disp_params:
476                par = par.strip()
477                param_str += "    disp_dict = PyDict_GetItemString(self->dispersion, \"%s\");\n" % par
478                param_str += "    self->model->%s.dispersion->accept_as_destination(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
479               
480            newline = self.replaceToken(newline, "[READDICTIONARY]", param_str)
481               
482            # Name of .c file
483            #toks = string.split(self.file,'.')
484            basename = os.path.basename(self.file)
485            toks = basename.split('.')
486            newline = self.replaceToken(newline, "[C_FILENAME]", toks[0])
487           
488            # Include file
489            basename = os.path.basename(self.file)
490            newline = self.replaceToken(newline, 
491                                        "[INCLUDE_FILE]", basename)           
492               
493            # Numerical calcs dealloc
494            dealloc_str = "\n"
495            if self.modelCalcFlag:
496                dealloc_str = "    modelcalculations_dealloc(&(self->model_pars.calcPars));\n"
497            newline = self.replaceToken(newline, 
498                                        "[NUMERICAL_DEALLOC]", dealloc_str)     
499               
500            # Numerical calcs init
501            init_str = "\n"
502            if self.modelCalcFlag:
503                init_str = "        modelcalculations_init(&(self->model_pars.calcPars));\n"
504            newline = self.replaceToken(newline, 
505                                        "[NUMERICAL_INIT]", init_str)     
506               
507            # Numerical calcs reset
508            reset_str = "\n"
509            if self.modelCalcFlag:
510                reset_str = "modelcalculations_reset(&(self->model_pars.calcPars));\n"
511            newline = self.replaceToken(newline, 
512                                        "[NUMERICAL_RESET]", reset_str)     
513               
514            # Setting dispsertion weights
515            set_weights = "    // Ugliness necessary to go from python to C\n"
516            set_weights = "    // TODO: refactor this\n"
517            for par in self.disp_params:
518                par = par.strip()
519                set_weights += "    if (!strcmp(par_name, \"%s\")) {\n" % par
520                set_weights += "        self->model->%s.dispersion = dispersion;\n" % par
521                set_weights += "    } else"
522            newline = self.replaceToken(newline, 
523                                        "[SET_DISPERSION]", set_weights)     
524           
525            # Write new line to the wrapper .c file
526            file.write(newline+'\n')
527           
528           
529        file.close()
530       
531    def write_python_wrapper(self):
532        """ Writes the python file to create the python extension class
533            The file is written in ../[PYTHONCLASS].py
534        """
535       
536        file = open("../"+self.pythonClass+'.py', 'w')
537        template = open("modelTemplate.txt", 'r')
538       
539        tmp_buf = template.read()
540        tmp_lines = tmp_buf.split('\n')
541       
542        for tmp_line in tmp_lines:
543           
544            # Catch class name
545            newline = self.replaceToken(tmp_line, 
546                                        "[CPYTHONCLASS]", 'C'+self.pythonClass)
547           
548            # Catch class name
549            newline = self.replaceToken(newline, 
550                                        "[PYTHONCLASS]", self.pythonClass)
551           
552            # Include file
553            newline = self.replaceToken(newline, 
554                                        "[INCLUDE_FILE]", self.file)   
555                   
556            # Include file
557            newline = self.replaceToken(newline, 
558                                        "[DEFAULT_LIST]", self.default_list)
559            # model description
560            newline = self.replaceToken(newline, 
561                                        "[DESCRIPTION]", self.description)
562            # Parameter details
563            newline = self.replaceToken(newline, 
564                                        "[PAR_DETAILS]", self.details)
565           
566            # fixed list  details
567            newline = self.replaceToken(newline, 
568                                        "[FIXED]",str(self.fixed))
569           
570            # Write new line to the wrapper .c file
571            file.write(newline+'\n')
572               
573        file.close()
574       
575       
576    def replaceToken(self, line, key, value): #pylint: disable-msg=R0201
577        """ Replace a token in the template file
578            @param line: line of text to inspect
579            @param key: token to look for
580            @param value: string value to replace the token with
581            @return: new string value
582        """
583        lenkey = len(key)
584        newline = line
585       
586        while newline.count(key)>0:
587            index = newline.index(key)
588            newline = newline[:index]+value+newline[index+lenkey:]
589       
590        return newline
591       
592       
593# main
594if __name__ == '__main__':
595    if len(sys.argv)>1:
596        print "Will look for file %s" % sys.argv[1]
597    #app = WrapperGenerator('../c_extensions/elliptical_cylinder.h')
598        app = WrapperGenerator(sys.argv[1])
599    else:
600        app = WrapperGenerator("test.h")
601    app.read()
602    app.write_c_wrapper()
603    app.write_python_wrapper()
604    print app
605   
606# End of file       
Note: See TracBrowser for help on using the repository browser.