source: sasview/sansmodels/src/python_wrapper/WrapperGenerator.py @ f686259

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since f686259 was 8389beb, checked in by Jae Cho <jhjcho@…>, 12 years ago

some pylint cleanups

  • Property mode set to 100644
File size: 23.0 KB
Line 
1#!/usr/bin/env python
2""" WrapperGenerator class to generate model code automatically.
3"""
4
5import os, sys,re
6import lineparser
7
8class WrapperGenerator:
9    """ Python wrapper generator for C models
10   
11        The developer must provide a header file describing
12        the new model.
13       
14        To provide the name of the Python class to be
15        generated, the .h file must contain the following
16        string in the comments:
17       
18        // [PYTHONCLASS] = my_model
19       
20        where my_model must be replaced by the name of the
21        class that you want to import from sans.models.
22        (example: [PYTHONCLASS] = MyModel
23          will create a class MyModel in sans.models.MyModel.
24          It will also create a class CMyModel in
25          sans_extension.c_models.)
26         
27        Also in comments, each parameter of the params
28        dictionary must be declared with a default value
29        in the following way:
30       
31        // [DEFAULT]=param_name=default_value
32       
33        (example:
34            //  [DEFAULT]=radius=20.0
35        )
36         
37        See cylinder.h for an example.
38       
39       
40        A .c file corresponding to the .h file should also
41        be provided (example: my_model.h, my_model.c).
42   
43        The .h file should define two function definitions. For example,
44        cylinder.h defines the following:
45       
46        /// 1D scattering function
47        double cylinder_analytical_1D(CylinderParameters *pars, double q);
48           
49        /// 2D scattering function
50        double cylinder_analytical_2D(CylinderParameters *pars, double q, double phi);
51           
52        The .c file implements those functions.
53       
54        @author: Mathieu Doucet / UTK
55        @contact: mathieu.doucet@nist.gov
56    """
57   
58    def __init__(self, filename, output_dir='.', c_wrapper_dir='.'):
59        """ Initialization """
60       
61        ## Name of .h file to generate wrapper from
62        self.file = filename
63       
64        # Info read from file
65       
66        ## Name of python class to write
67        self.pythonClass = None
68        ## Parser in struct section
69        self.inStruct = False
70        self.foundCPP = False
71        self.inParDefs = False
72        ## Name of struct for the c object
73        self.structName = None
74        ## Dictionary of parameters
75        self.params = {}
76        ## ModelCalculation module flag
77        self.modelCalcFlag = False
78        ## List of default parameters (text)
79        self.default_list = ""
80        ## Dictionary of units
81        self.details = ""
82        ## List of dispersed parameters
83        self.disp_params = []
84        #model description
85        self.description=''
86        # paramaters for fittable
87        self.fixed= []
88        # paramaters for non-fittable
89        self.non_fittable= []
90        ## parameters with orientation
91        self.orientation_params =[]
92        # Model category
93        self.category = None
94        # Whether model belongs to multifunc
95        self.is_multifunc = False
96        ## output directory for wrappers
97        self.output_dir = output_dir
98        self.c_wrapper_dir = c_wrapper_dir
99
100       
101       
102    def __repr__(self):
103        """ Simple output for printing """
104       
105        rep  = "\n Python class: %s\n\n" % self.pythonClass
106        rep += "  struc name: %s\n\n" % self.structName
107        rep += "  params:     %s\n\n" % self.params
108        rep += "  description:    %s\n\n" % self.description
109        rep += "  Fittable parameters:     %s\n\n"% self.fixed
110        rep += "  Non-Fittable parameters:     %s\n\n"% self.non_fittable
111        rep += "  Orientation parameters:  %s\n\n"% self.orientation_params
112        return rep
113       
114    def read(self):
115        """ Reads in the .h file to catch parameters of the wrapper """
116       
117        # Check if the file is there
118        if not os.path.isfile(self.file):
119            raise ValueError, "File %s is not a regular file" % self.file
120       
121        # Read file
122        f = open(self.file,'r')
123        buf = f.read()
124       
125        self.default_list = "List of default parameters:\n"
126        #lines = string.split(buf,'\n')
127        lines = buf.split('\n')
128        self.details  = "## Parameter details [units, min, max]\n"
129        self.details += "        self.details = {}\n"
130       
131        #open item in this case Fixed
132        text='text'
133        key2="<%s>"%text.lower()
134        # close an item in this case fixed
135        text='TexT'
136        key3="</%s>"%text.lower()
137       
138        ## Catch fixed parameters
139        key = "[FIXED]"
140        try:
141            self.fixed= lineparser.readhelper(lines, key, 
142                                              key2, key3, file=self.file)
143        except:
144           raise   
145        ## Catch non-fittable parameters parameters
146        key = "[NON_FITTABLE_PARAMS]"
147        try:
148            self.non_fittable= lineparser.readhelper(lines, key, key2,
149                                                     key3, file=self.file)
150        except:
151           raise   
152
153        ## Catch parameters with orientation
154        key = "[ORIENTATION_PARAMS]"   
155        try:
156            self.orientation_params = lineparser.readhelper(lines, key, 
157                                                    key2, key3, file=self.file)
158        except:
159           raise 
160        ## Catch Description
161        key = "[DESCRIPTION]"
162       
163        find_description = False
164        temp=""
165        for line in lines:
166            if line.count(key)>0 :
167               
168                try:
169                    find_description= True
170                    index = line.index(key)
171                    toks = line[index:].split("=",1 )
172                    temp=toks[1].lstrip().rstrip()
173                    text='text'
174                    key2="<%s>"%text.lower()
175                    if re.match(key2,temp)!=None:
176   
177                        toks2=temp.split(key2,1)
178                        self.description=toks2[1]
179                        text='text'
180                        key2="</%s>"%text.lower()
181                        if re.search(key2,toks2[1])!=None:
182                            temp=toks2[1].split(key2,1)
183                            self.description=temp[0]
184                            break
185                     
186                    else:
187                        self.description=temp
188                        break
189                except:
190                     raise ValueError, "Could not parse file %s" % self.file
191            elif find_description:
192                text='text'
193                key2="</%s>"%text.lower()
194                if re.search(key2,line)!=None:
195                    tok=line.split(key2,1)
196                    temp=tok[0].split("//",1)
197                    self.description+=tok[1].lstrip().rstrip()
198                    break
199                else:
200                    if re.search("//",line)!=None:
201                        temp=line.split("//",1)
202                        self.description+='\n\t\t'+temp[1].lstrip().rstrip()
203                       
204                    else:
205                        self.description+='\n\t\t'+line.lstrip().rstrip()
206                   
207               
208     
209        for line in lines:
210           
211            # Catch class name
212            key = "[PYTHONCLASS]"
213            if line.count(key)>0:
214                try:
215                    index = line.index(key)
216                    toks = line[index:].split("=" )
217                    self.pythonClass = toks[1].lstrip().rstrip()
218
219                except:
220                    raise ValueError, "Could not parse file %s" % self.file
221               
222            key = "[CATEGORY]"
223            if line.count(key)>0:
224                try:
225                    index = line.index(key)
226                    toks = line[index:].split("=")
227                    self.category = toks[1].lstrip().rstrip()
228
229                except:
230                    raise ValueError, "Could not parse file %s" % self.file
231
232            # is_multifunc
233            key = "[MULTIPLICITY_INFO]"
234            if line.count(key) > 0:
235                self.is_multifunc = True
236                try:
237                    index = line.index(key)
238                    toks = line[index:].split("=")
239                    self.multiplicity_info = toks[1].lstrip().rstrip()
240                except:
241                    raise ValueError, "Could not parse file %s" % self.file
242
243            # Catch struct name
244            # C++ class definition
245            if line.count("class")>0:
246                # We are entering a class definition
247                self.inParDefs = True
248                self.foundCPP = True
249               
250            # Old-Style C struct definition
251            if line.count("typedef struct")>0:
252                # We are entering a struct block
253                self.inParDefs = True
254                self.inStruct = True
255           
256            if self.inParDefs and line.count("}")>0:
257                # We are exiting a struct block
258                self.inParDefs = False
259               
260                if self.inStruct:
261                    self.inStruct = False
262                    # Catch the name of the struct
263                    index = line.index("}")
264                    toks = line[index+1:].split(";")
265                    # Catch pointer definition
266                    toks2 = toks[0].split(',')
267                    self.structName = toks2[0].lstrip().rstrip()
268           
269            # Catch struct content
270            key = "[DEFAULT]"
271            if self.inParDefs and line.count(key)>0:
272                # Found a new parameter
273                try:
274                    index = line.index(key)
275                    toks = line[index:].split("=")
276                    toks2 = toks[2].split()
277                    val = float(toks2[0])
278                    self.params[toks[1]] = val
279                    #self.pythonClass = toks[1].lstrip().rstrip()
280                    units = ""
281                    if len(toks2) >= 2:
282                        units = toks2[1]
283                    self.default_list += "         %-15s = %s %s\n" % \
284                        (toks[1], val, units)
285                   
286                    # Check for min and max
287                    min = "None"
288                    max = "None"
289                    if len(toks2) == 4:
290                        min = toks2[2]
291                        max = toks2[3]
292                   
293                    self.details += "        self.details['%s'] = ['%s', %s, %s]\n" % \
294                        (toks[1].lstrip().rstrip(), units.lstrip().rstrip(), min, max)
295                except:
296                    raise ValueError, "Could not parse input file %s \n  %s" % \
297                        (self.file, sys.exc_value)
298               
299               
300            # Catch need for numerical calculations
301            key = "CalcParameters calcPars"
302            if line.count(key)>0:
303                self.modelCalcFlag = True
304               
305            # Catch list of dispersed parameters
306            key = "[DISP_PARAMS]"
307            if line.count(key)>0:
308                try:
309                    index = line.index(key)
310                    toks = line[index:].split("=")
311                    list_str = toks[1].lstrip().rstrip()
312                    self.disp_params = list_str.split(',')
313                except:
314                    raise ValueError, "Could not parse file %s" % self.file
315
316    def write_c_wrapper(self):
317        """ Writes the C file to create the python extension class
318            The file is written in C[PYTHONCLASS].c
319        """
320        file_path = os.path.join(self.c_wrapper_dir, 
321                                 "C"+self.pythonClass+'.cpp')
322        file = open(file_path, 'w')
323       
324        template = open(os.path.join(os.path.dirname(__file__), 
325                                     "classTemplate.txt"), 'r')
326       
327        tmp_buf = template.read()
328        #tmp_lines = string.split(tmp_buf,'\n')
329        tmp_lines = tmp_buf.split('\n')
330       
331        for tmp_line in tmp_lines:
332           
333            # Catch class name
334            newline = self.replaceToken(tmp_line, 
335                                        "[PYTHONCLASS]", 'C'+self.pythonClass)
336            #Catch model description
337            #newline = self.replaceToken(tmp_line,
338            #                            "[DESCRIPTION]", self.description)
339            # Catch C model name
340            newline = self.replaceToken(newline, 
341                                        "[CMODEL]", self.pythonClass)
342           
343            # Catch class name
344            newline = self.replaceToken(newline, 
345                                        "[MODELSTRUCT]", self.structName)
346
347            # Sort model initialization based on multifunc
348            if(self.is_multifunc):
349                line = "int level = 1;\nPyArg_ParseTuple(args,\"i\",&level);\n"
350                line += "self->model = new " + self.pythonClass + "(level);"
351            else:
352                line = "self->model = new " + self.pythonClass + "();"
353   
354            newline = self.replaceToken(newline,"[INITIALIZE_MODEL]",
355                                            line)
356           
357            # Dictionary initialization
358            param_str = "// Initialize parameter dictionary\n"           
359            for par in self.params:
360                param_str += "        PyDict_SetItemString(self->params,\"%s\",Py_BuildValue(\"d\",%10.12f));\n" % \
361                    (par, self.params[par])
362
363            if len(self.disp_params)>0:
364                param_str += "        // Initialize dispersion / averaging parameter dict\n"
365                param_str += "        DispersionVisitor* visitor = new DispersionVisitor();\n"
366                param_str += "        PyObject * disp_dict;\n"
367                for par in self.disp_params:
368                    par = par.strip()
369                    param_str += "        disp_dict = PyDict_New();\n"
370                    param_str += "        self->model->%s.dispersion->accept_as_source(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
371                    param_str += "        PyDict_SetItemString(self->dispersion, \"%s\", disp_dict);\n" % par
372               
373            # Initialize dispersion object dictionnary
374            param_str += "\n"
375           
376               
377            newline = self.replaceToken(newline,
378                                        "[INITDICTIONARY]", param_str)
379           
380            # Read dictionary
381            param_str = "    // Reader parameter dictionary\n"
382            for par in self.params:
383                param_str += "    self->model->%s = PyFloat_AsDouble( PyDict_GetItemString(self->params, \"%s\") );\n" % \
384                    (par, par)
385                   
386            if len(self.disp_params)>0:
387                param_str += "    // Read in dispersion parameters\n"
388                param_str += "    PyObject* disp_dict;\n"
389                param_str += "    DispersionVisitor* visitor = new DispersionVisitor();\n"
390                for par in self.disp_params:
391                    par = par.strip()
392                    param_str += "    disp_dict = PyDict_GetItemString(self->dispersion, \"%s\");\n" % par
393                    param_str += "    self->model->%s.dispersion->accept_as_destination(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
394               
395            newline = self.replaceToken(newline, "[READDICTIONARY]", param_str)
396               
397            # Name of .c file
398            #toks = string.split(self.file,'.')
399            basename = os.path.basename(self.file)
400            toks = basename.split('.')
401            newline = self.replaceToken(newline, "[C_FILENAME]", toks[0])
402           
403            # Include file
404            basename = os.path.basename(self.file)
405            newline = self.replaceToken(newline, 
406                                        "[INCLUDE_FILE]", self.file) 
407            if self.foundCPP:
408                newline = self.replaceToken(newline, 
409                            "[C_INCLUDE_FILE]", "") 
410                newline = self.replaceToken(newline, 
411                            "[CPP_INCLUDE_FILE]", "#include \"%s\"" % basename) 
412            else: 
413                newline = self.replaceToken(newline, 
414                            "[C_INCLUDE_FILE]", "#include \"%s\"" % basename)   
415                newline = self.replaceToken(newline, 
416                            "[CPP_INCLUDE_FILE]", "#include \"models.hh\"") 
417               
418            # Numerical calcs dealloc
419            dealloc_str = "\n"
420            if self.modelCalcFlag:
421                dealloc_str = "    modelcalculations_dealloc(&(self->model_pars.calcPars));\n"
422            newline = self.replaceToken(newline, 
423                                        "[NUMERICAL_DEALLOC]", dealloc_str)     
424               
425            # Numerical calcs init
426            init_str = "\n"
427            if self.modelCalcFlag:
428                init_str = "        modelcalculations_init(&(self->model_pars.calcPars));\n"
429            newline = self.replaceToken(newline, 
430                                        "[NUMERICAL_INIT]", init_str)     
431               
432            # Numerical calcs reset
433            reset_str = "\n"
434            if self.modelCalcFlag:
435                reset_str = "modelcalculations_reset(&(self->model_pars.calcPars));\n"
436            newline = self.replaceToken(newline, 
437                                        "[NUMERICAL_RESET]", reset_str)     
438               
439            # Setting dispsertion weights
440            set_weights = "    // Ugliness necessary to go from python to C\n"
441            set_weights = "    // TODO: refactor this\n"
442            for par in self.disp_params:
443                par = par.strip()
444                set_weights += "    if (!strcmp(par_name, \"%s\")) {\n" % par
445                set_weights += "        self->model->%s.dispersion = dispersion;\n" % par
446                set_weights += "    } else"
447            newline = self.replaceToken(newline, 
448                                        "[SET_DISPERSION]", set_weights)     
449           
450            # Write new line to the wrapper .c file
451            file.write(newline+'\n')
452           
453           
454        file.close()
455       
456    def write_python_wrapper(self):
457        """ Writes the python file to create the python extension class
458            The file is written in ../[PYTHONCLASS].py
459        """
460        file_path = os.path.join(self.output_dir, self.pythonClass+'.py')
461        file = open(file_path, 'w')
462        template = open(os.path.join(os.path.dirname(__file__), 
463                                     "modelTemplate.txt"), 'r')
464       
465        tmp_buf = template.read()
466        tmp_lines = tmp_buf.split('\n')
467       
468        for tmp_line in tmp_lines:
469           
470            # Catch class name
471            newline = self.replaceToken(tmp_line, 
472                                        "[CPYTHONCLASS]", 'C'+self.pythonClass)
473           
474            # Catch class name
475            newline = self.replaceToken(newline, 
476                                        "[PYTHONCLASS]", self.pythonClass)
477           
478            # Include file
479            newline = self.replaceToken(newline, 
480                                        "[INCLUDE_FILE]", self.file)   
481                   
482            # Include file
483            newline = self.replaceToken(newline, 
484                                        "[DEFAULT_LIST]", self.default_list)
485            # model description
486            newline = self.replaceToken(newline, 
487                                        "[DESCRIPTION]", self.description)
488            # Parameter details
489            newline = self.replaceToken(newline, 
490                                        "[PAR_DETAILS]", self.details)
491           
492            # Call base constructor
493            if self.is_multifunc:
494                newline = self.replaceToken(newline,"[CALL_CPYTHON_INIT]",
495                    'C' + self.pythonClass + \
496                    ".__init__(self,multfactor)\n\tself.is_multifunc = True")
497                newline = self.replaceToken(newline,"[MULTIPLICITY_INFO]", 
498                                            self.multiplicity_info)
499            else:
500                newline = self.replaceToken(newline,"[CALL_CPYTHON_INIT]",
501                    'C' + self.pythonClass + \
502                    ".__init__(self)\n        self.is_multifunc = False")
503                newline = self.replaceToken(newline, 
504                                            "[MULTIPLICITY_INFO]","None")
505
506           
507            # fixed list  details
508            fixed_str = str(self.fixed)
509            fixed_str = fixed_str.replace(', ', ',\n                      ')
510            newline = self.replaceToken(newline, "[FIXED]",fixed_str)
511           
512            # non-fittable list details
513            pars_str = str(self.non_fittable)
514            pars_str = pars_str.replace(', ', 
515                                        ',\n                             ')
516            newline = self.replaceToken(newline, 
517                                        "[NON_FITTABLE_PARAMS]",
518                                        pars_str)
519           
520            ## parameters with orientation
521            oriented_str = str(self.orientation_params)
522            formatted_endl = ',\n                                   '
523            oriented_str = oriented_str.replace(', ', formatted_endl)
524            newline = self.replaceToken(newline, 
525                               "[ORIENTATION_PARAMS]", oriented_str)
526           
527            if self.category:
528                newline = self.replaceToken(newline,"[CATEGORY]"
529                                            ,'"' + self.category + '"')
530            else:
531                newline = self.replaceToken(newline,"[CATEGORY]",
532                                            "None")
533           
534
535
536            # Write new line to the wrapper .c file
537            file.write(newline+'\n')
538               
539        file.close()
540       
541       
542    def replaceToken(self, line, key, value): #pylint: disable-msg=R0201
543        """ Replace a token in the template file
544            @param line: line of text to inspect
545            @param key: token to look for
546            @param value: string value to replace the token with
547            @return: new string value
548        """
549        lenkey = len(key)
550        newline = line
551       
552        while newline.count(key)>0:
553            index = newline.index(key)
554            newline = newline[:index]+value+newline[index+lenkey:]
555       
556        return newline
557       
558    def getModelName(self):
559        return self.pythonClass
560       
561
562
563# main
564if __name__ == '__main__':
565    if len(sys.argv)>1:
566        print "Will look for file %s" % sys.argv[1]
567        app = WrapperGenerator(sys.argv[1])
568    else:
569        app = WrapperGenerator("test.h")
570    app.read()
571    app.write_c_wrapper()
572    app.write_python_wrapper()
573    print app
574   
575# End of file       
Note: See TracBrowser for help on using the repository browser.