source: sasview/sansmodels/src/python_wrapper/WrapperGenerator.py @ 8fcf331

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 8fcf331 was 0ea247e, checked in by Mathieu Doucet <doucetm@…>, 13 years ago

Fix problem with model template

  • Property mode set to 100644
File size: 20.6 KB
Line 
1#!/usr/bin/env python
2""" WrapperGenerator class to generate model code automatically.
3"""
4
5import os, sys,re
6import lineparser
7
8class WrapperGenerator:
9    """ Python wrapper generator for C models
10   
11        The developer must provide a header file describing
12        the new model.
13       
14        To provide the name of the Python class to be
15        generated, the .h file must contain the following
16        string in the comments:
17       
18        // [PYTHONCLASS] = my_model
19       
20        where my_model must be replaced by the name of the
21        class that you want to import from sans.models.
22        (example: [PYTHONCLASS] = MyModel
23          will create a class MyModel in sans.models.MyModel.
24          It will also create a class CMyModel in
25          sans_extension.c_models.)
26         
27        Also in comments, each parameter of the params
28        dictionary must be declared with a default value
29        in the following way:
30       
31        // [DEFAULT]=param_name=default_value
32       
33        (example:
34            //  [DEFAULT]=radius=20.0
35        )
36         
37        See cylinder.h for an example.
38       
39       
40        A .c file corresponding to the .h file should also
41        be provided (example: my_model.h, my_model.c).
42   
43        The .h file should define two function definitions. For example,
44        cylinder.h defines the following:
45       
46            /// 1D scattering function
47            double cylinder_analytical_1D(CylinderParameters *pars, double q);
48           
49            /// 2D scattering function
50            double cylinder_analytical_2D(CylinderParameters *pars, double q, double phi);
51           
52        The .c file implements those functions.
53       
54        @author: Mathieu Doucet / UTK
55        @contact: mathieu.doucet@nist.gov
56    """
57   
58    def __init__(self, filename, output_dir='.', c_wrapper_dir='.'):
59        """ Initialization """
60       
61        ## Name of .h file to generate wrapper from
62        self.file = filename
63       
64        # Info read from file
65       
66        ## Name of python class to write
67        self.pythonClass = None
68        ## Parser in struct section
69        self.inStruct = False
70        self.foundCPP = False
71        self.inParDefs = False
72        ## Name of struct for the c object
73        self.structName = None
74        ## Dictionary of parameters
75        self.params = {}
76        ## ModelCalculation module flag
77        self.modelCalcFlag = False
78        ## List of default parameters (text)
79        self.default_list = ""
80        ## Dictionary of units
81        self.details = ""
82        ## List of dispersed parameters
83        self.disp_params = []
84        #model description
85        self.description=''
86        # paramaters for fittable
87        self.fixed= []
88        # paramaters for non-fittable
89        self.non_fittable= []
90        ## parameters with orientation
91        self.orientation_params =[]
92        ## output directory for wrappers
93        self.output_dir = output_dir
94        self.c_wrapper_dir = c_wrapper_dir
95       
96       
97    def __repr__(self):
98        """ Simple output for printing """
99       
100        rep  = "\n Python class: %s\n\n" % self.pythonClass
101        rep += "  struc name: %s\n\n" % self.structName
102        rep += "  params:     %s\n\n" % self.params
103        rep += "  description:    %s\n\n" % self.description
104        rep += "  Fittable parameters:     %s\n\n"% self.fixed
105        rep += "  Non-Fittable parameters:     %s\n\n"% self.non_fittable
106        rep += "  Orientation parameters:  %s\n\n"% self.orientation_params
107        return rep
108       
109    def read(self):
110        """ Reads in the .h file to catch parameters of the wrapper """
111       
112        # Check if the file is there
113        if not os.path.isfile(self.file):
114            raise ValueError, "File %s is not a regular file" % self.file
115       
116        # Read file
117        f = open(self.file,'r')
118        buf = f.read()
119       
120        self.default_list = "List of default parameters:\n"
121        #lines = string.split(buf,'\n')
122        lines = buf.split('\n')
123        self.details  = "## Parameter details [units, min, max]\n"
124        self.details += "        self.details = {}\n"
125       
126        #open item in this case Fixed
127        text='text'
128        key2="<%s>"%text.lower()
129        # close an item in this case fixed
130        text='TexT'
131        key3="</%s>"%text.lower()
132       
133        ## Catch fixed parameters
134        key = "[FIXED]"
135        try:
136            self.fixed= lineparser.readhelper(lines, key, 
137                                              key2, key3, file=self.file)
138        except:
139           raise   
140        ## Catch non-fittable parameters parameters
141        key = "[NON_FITTABLE_PARAMS]"
142        try:
143            self.non_fittable= lineparser.readhelper(lines, key, key2,
144                                                     key3, file=self.file)
145        except:
146           raise   
147
148        ## Catch parameters with orientation
149        key = "[ORIENTATION_PARAMS]"   
150        try:
151            self.orientation_params = lineparser.readhelper(lines, key, 
152                                                    key2, key3, file=self.file)
153        except:
154           raise 
155        ## Catch Description
156        key = "[DESCRIPTION]"
157       
158        find_description = False
159        temp=""
160        for line in lines:
161            if line.count(key)>0 :
162               
163                try:
164                    find_description= True
165                    index = line.index(key)
166                    toks = line[index:].split("=",1 )
167                    temp=toks[1].lstrip().rstrip()
168                    text='text'
169                    key2="<%s>"%text.lower()
170                    if re.match(key2,temp)!=None:
171   
172                        toks2=temp.split(key2,1)
173                        self.description=toks2[1]
174                        text='text'
175                        key2="</%s>"%text.lower()
176                        if re.search(key2,toks2[1])!=None:
177                            temp=toks2[1].split(key2,1)
178                            self.description=temp[0]
179                            break
180                     
181                    else:
182                        self.description=temp
183                        break
184                except:
185                     raise ValueError, "Could not parse file %s" % self.file
186            elif find_description:
187                text='text'
188                key2="</%s>"%text.lower()
189                if re.search(key2,line)!=None:
190                    tok=line.split(key2,1)
191                    temp=tok[0].split("//",1)
192                    self.description+=tok[1].lstrip().rstrip()
193                    break
194                else:
195                    if re.search("//",line)!=None:
196                        temp=line.split("//",1)
197                        self.description+='\n\t\t'+temp[1].lstrip().rstrip()
198                       
199                    else:
200                        self.description+='\n\t\t'+line.lstrip().rstrip()
201                   
202               
203               
204        for line in lines:
205           
206            # Catch class name
207            key = "[PYTHONCLASS]"
208            if line.count(key)>0:
209                try:
210                    index = line.index(key)
211                    #toks = string.split( line[index:], "=" )
212                    toks = line[index:].split("=" )
213                    self.pythonClass = toks[1].lstrip().rstrip()
214                except:
215                    raise ValueError, "Could not parse file %s" % self.file
216               
217            # Catch struct name
218            # C++ class definition
219            if line.count("class")>0:
220                # We are entering a class definition
221                self.inParDefs = True
222                self.foundCPP = True
223               
224            # Old-Style C struct definition
225            if line.count("typedef struct")>0:
226                # We are entering a struct block
227                self.inParDefs = True
228                self.inStruct = True
229           
230            if self.inParDefs and line.count("}")>0:
231                # We are exiting a struct block
232                self.inParDefs = False
233               
234                if self.inStruct:
235                    self.inStruct = False
236                    # Catch the name of the struct
237                    index = line.index("}")
238                    toks = line[index+1:].split(";")
239                    # Catch pointer definition
240                    toks2 = toks[0].split(',')
241                    self.structName = toks2[0].lstrip().rstrip()
242           
243            # Catch struct content
244            key = "[DEFAULT]"
245            if self.inParDefs and line.count(key)>0:
246                # Found a new parameter
247                try:
248                    index = line.index(key)
249                    toks = line[index:].split("=")
250                    toks2 = toks[2].split()
251                    val = float(toks2[0])
252                    self.params[toks[1]] = val
253                    #self.pythonClass = toks[1].lstrip().rstrip()
254                    units = ""
255                    if len(toks2) >= 2:
256                        units = toks2[1]
257                    self.default_list += "         %-15s = %s %s\n" % \
258                        (toks[1], val, units)
259                   
260                    # Check for min and max
261                    min = "None"
262                    max = "None"
263                    if len(toks2) == 4:
264                        min = toks2[2]
265                        max = toks2[3]
266                   
267                    self.details += "        self.details['%s'] = ['%s', %s, %s]\n" % \
268                        (toks[1].lstrip().rstrip(), units.lstrip().rstrip(), min, max)
269                except:
270                    raise ValueError, "Could not parse input file %s \n  %s" % \
271                        (self.file, sys.exc_value)
272               
273               
274            # Catch need for numerical calculations
275            key = "CalcParameters calcPars"
276            if line.count(key)>0:
277                self.modelCalcFlag = True
278               
279            # Catch list of dispersed parameters
280            key = "[DISP_PARAMS]"
281            if line.count(key)>0:
282                try:
283                    index = line.index(key)
284                    toks = line[index:].split("=")
285                    list_str = toks[1].lstrip().rstrip()
286                    self.disp_params = list_str.split(',')
287                except:
288                    raise ValueError, "Could not parse file %s" % self.file
289               
290       
291               
292    def write_c_wrapper(self):
293        """ Writes the C file to create the python extension class
294            The file is written in C[PYTHONCLASS].c
295        """
296        file_path = os.path.join(self.c_wrapper_dir, "C"+self.pythonClass+'.cpp')
297        file = open(file_path, 'w')
298       
299        template = open(os.path.join(os.path.dirname(__file__), "classTemplate.txt"), 'r')
300       
301        tmp_buf = template.read()
302        #tmp_lines = string.split(tmp_buf,'\n')
303        tmp_lines = tmp_buf.split('\n')
304       
305        for tmp_line in tmp_lines:
306           
307            # Catch class name
308            newline = self.replaceToken(tmp_line, 
309                                        "[PYTHONCLASS]", 'C'+self.pythonClass)
310            #Catch model description
311            #newline = self.replaceToken(tmp_line,
312            #                            "[DESCRIPTION]", self.description)
313            # Catch C model name
314            newline = self.replaceToken(newline, 
315                                        "[CMODEL]", self.pythonClass)
316           
317            # Catch class name
318            newline = self.replaceToken(newline, 
319                                        "[MODELSTRUCT]", self.structName)
320           
321            # Dictionary initialization
322            param_str = "// Initialize parameter dictionary\n"           
323            for par in self.params:
324                param_str += "        PyDict_SetItemString(self->params,\"%s\",Py_BuildValue(\"d\",%10.12f));\n" % \
325                    (par, self.params[par])
326
327            if len(self.disp_params)>0:
328                param_str += "        // Initialize dispersion / averaging parameter dict\n"
329                param_str += "        DispersionVisitor* visitor = new DispersionVisitor();\n"
330                param_str += "        PyObject * disp_dict;\n"
331                for par in self.disp_params:
332                    par = par.strip()
333                    param_str += "        disp_dict = PyDict_New();\n"
334                    param_str += "        self->model->%s.dispersion->accept_as_source(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
335                    param_str += "        PyDict_SetItemString(self->dispersion, \"%s\", disp_dict);\n" % par
336               
337            # Initialize dispersion object dictionnary
338            param_str += "\n"
339           
340               
341            newline = self.replaceToken(newline,
342                                        "[INITDICTIONARY]", param_str)
343           
344            # Read dictionary
345            param_str = "    // Reader parameter dictionary\n"
346            for par in self.params:
347                param_str += "    self->model->%s = PyFloat_AsDouble( PyDict_GetItemString(self->params, \"%s\") );\n" % \
348                    (par, par)
349                   
350            if len(self.disp_params)>0:
351                param_str += "    // Read in dispersion parameters\n"
352                param_str += "    PyObject* disp_dict;\n"
353                param_str += "    DispersionVisitor* visitor = new DispersionVisitor();\n"
354                for par in self.disp_params:
355                    par = par.strip()
356                    param_str += "    disp_dict = PyDict_GetItemString(self->dispersion, \"%s\");\n" % par
357                    param_str += "    self->model->%s.dispersion->accept_as_destination(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
358               
359            newline = self.replaceToken(newline, "[READDICTIONARY]", param_str)
360               
361            # Name of .c file
362            #toks = string.split(self.file,'.')
363            basename = os.path.basename(self.file)
364            toks = basename.split('.')
365            newline = self.replaceToken(newline, "[C_FILENAME]", toks[0])
366           
367            # Include file
368            basename = os.path.basename(self.file)
369            newline = self.replaceToken(newline, 
370                                        "[INCLUDE_FILE]", self.file) 
371            if self.foundCPP:
372                newline = self.replaceToken(newline, 
373                                            "[C_INCLUDE_FILE]", "") 
374                newline = self.replaceToken(newline, 
375                                            "[CPP_INCLUDE_FILE]", "#include \"%s\"" % basename) 
376            else: 
377                newline = self.replaceToken(newline, 
378                                            "[C_INCLUDE_FILE]", "#include \"%s\"" % basename)   
379                newline = self.replaceToken(newline, 
380                                            "[CPP_INCLUDE_FILE]", "#include \"models.hh\"") 
381               
382            # Numerical calcs dealloc
383            dealloc_str = "\n"
384            if self.modelCalcFlag:
385                dealloc_str = "    modelcalculations_dealloc(&(self->model_pars.calcPars));\n"
386            newline = self.replaceToken(newline, 
387                                        "[NUMERICAL_DEALLOC]", dealloc_str)     
388               
389            # Numerical calcs init
390            init_str = "\n"
391            if self.modelCalcFlag:
392                init_str = "        modelcalculations_init(&(self->model_pars.calcPars));\n"
393            newline = self.replaceToken(newline, 
394                                        "[NUMERICAL_INIT]", init_str)     
395               
396            # Numerical calcs reset
397            reset_str = "\n"
398            if self.modelCalcFlag:
399                reset_str = "modelcalculations_reset(&(self->model_pars.calcPars));\n"
400            newline = self.replaceToken(newline, 
401                                        "[NUMERICAL_RESET]", reset_str)     
402               
403            # Setting dispsertion weights
404            set_weights = "    // Ugliness necessary to go from python to C\n"
405            set_weights = "    // TODO: refactor this\n"
406            for par in self.disp_params:
407                par = par.strip()
408                set_weights += "    if (!strcmp(par_name, \"%s\")) {\n" % par
409                set_weights += "        self->model->%s.dispersion = dispersion;\n" % par
410                set_weights += "    } else"
411            newline = self.replaceToken(newline, 
412                                        "[SET_DISPERSION]", set_weights)     
413           
414            # Write new line to the wrapper .c file
415            file.write(newline+'\n')
416           
417           
418        file.close()
419       
420    def write_python_wrapper(self):
421        """ Writes the python file to create the python extension class
422            The file is written in ../[PYTHONCLASS].py
423        """
424        file_path = os.path.join(self.output_dir, self.pythonClass+'.py')
425        file = open(file_path, 'w')
426        template = open(os.path.join(os.path.dirname(__file__), "modelTemplate.txt"), 'r')
427       
428        tmp_buf = template.read()
429        tmp_lines = tmp_buf.split('\n')
430       
431        for tmp_line in tmp_lines:
432           
433            # Catch class name
434            newline = self.replaceToken(tmp_line, 
435                                        "[CPYTHONCLASS]", 'C'+self.pythonClass)
436           
437            # Catch class name
438            newline = self.replaceToken(newline, 
439                                        "[PYTHONCLASS]", self.pythonClass)
440           
441            # Include file
442            newline = self.replaceToken(newline, 
443                                        "[INCLUDE_FILE]", self.file)   
444                   
445            # Include file
446            newline = self.replaceToken(newline, 
447                                        "[DEFAULT_LIST]", self.default_list)
448            # model description
449            newline = self.replaceToken(newline, 
450                                        "[DESCRIPTION]", self.description)
451            # Parameter details
452            newline = self.replaceToken(newline, 
453                                        "[PAR_DETAILS]", self.details)
454           
455            # fixed list  details
456            fixed_str = str(self.fixed)
457            fixed_str = fixed_str.replace(', ', ',\n                      ')
458            newline = self.replaceToken(newline, "[FIXED]",fixed_str)
459           
460            # non-fittable list details
461            pars_str = str(self.non_fittable)
462            pars_str = pars_str.replace(', ', 
463                                        ',\n                             ')
464            newline = self.replaceToken(newline, 
465                                        "[NON_FITTABLE_PARAMS]",
466                                        pars_str)
467           
468            ## parameters with orientation
469            oriented_str = str(self.orientation_params)
470            formatted_endl = ',\n                                   '
471            oriented_str = oriented_str.replace(', ', formatted_endl)
472            newline = self.replaceToken(newline, 
473                               "[ORIENTATION_PARAMS]", oriented_str)
474           
475            # Write new line to the wrapper .c file
476            file.write(newline+'\n')
477               
478        file.close()
479       
480       
481    def replaceToken(self, line, key, value): #pylint: disable-msg=R0201
482        """ Replace a token in the template file
483            @param line: line of text to inspect
484            @param key: token to look for
485            @param value: string value to replace the token with
486            @return: new string value
487        """
488        lenkey = len(key)
489        newline = line
490       
491        while newline.count(key)>0:
492            index = newline.index(key)
493            newline = newline[:index]+value+newline[index+lenkey:]
494       
495        return newline
496       
497       
498# main
499if __name__ == '__main__':
500    if len(sys.argv)>1:
501        print "Will look for file %s" % sys.argv[1]
502        app = WrapperGenerator(sys.argv[1])
503    else:
504        app = WrapperGenerator("test.h")
505    app.read()
506    app.write_c_wrapper()
507    app.write_python_wrapper()
508    print app
509   
510# End of file       
Note: See TracBrowser for help on using the repository browser.