source: sasview/sansmodels/src/python_wrapper/WrapperGenerator.py @ 0a9686d

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 0a9686d was 0a9686d, checked in by Kieran Campbell <kieranrcampbell@…>, 12 years ago

Updated python wrapping to automatically generate c_models.cpp

  • Property mode set to 100644
File size: 20.5 KB
Line 
1#!/usr/bin/env python
2""" WrapperGenerator class to generate model code automatically.
3"""
4
5import os, sys,re
6import lineparser
7
8class WrapperGenerator:
9    """ Python wrapper generator for C models
10   
11        The developer must provide a header file describing
12        the new model.
13       
14        To provide the name of the Python class to be
15        generated, the .h file must contain the following
16        string in the comments:
17       
18        // [PYTHONCLASS] = my_model
19       
20        where my_model must be replaced by the name of the
21        class that you want to import from sans.models.
22        (example: [PYTHONCLASS] = MyModel
23          will create a class MyModel in sans.models.MyModel.
24          It will also create a class CMyModel in
25          sans_extension.c_models.)
26         
27        Also in comments, each parameter of the params
28        dictionary must be declared with a default value
29        in the following way:
30       
31        // [DEFAULT]=param_name=default_value
32       
33        (example:
34            //  [DEFAULT]=radius=20.0
35        )
36         
37        See cylinder.h for an example.
38       
39       
40        A .c file corresponding to the .h file should also
41        be provided (example: my_model.h, my_model.c).
42   
43        The .h file should define two function definitions. For example,
44        cylinder.h defines the following:
45       
46            /// 1D scattering function
47            double cylinder_analytical_1D(CylinderParameters *pars, double q);
48           
49            /// 2D scattering function
50            double cylinder_analytical_2D(CylinderParameters *pars, double q, double phi);
51           
52        The .c file implements those functions.
53       
54        @author: Mathieu Doucet / UTK
55        @contact: mathieu.doucet@nist.gov
56    """
57   
58    def __init__(self, filename, output_dir='.', c_wrapper_dir='.'):
59        """ Initialization """
60       
61        ## Name of .h file to generate wrapper from
62        self.file = filename
63       
64        # Info read from file
65       
66        ## Name of python class to write
67        self.pythonClass = None
68        ## Parser in struct section
69        self.inStruct = False
70        self.foundCPP = False
71        self.inParDefs = False
72        ## Name of struct for the c object
73        self.structName = None
74        ## Dictionary of parameters
75        self.params = {}
76        ## ModelCalculation module flag
77        self.modelCalcFlag = False
78        ## List of default parameters (text)
79        self.default_list = ""
80        ## Dictionary of units
81        self.details = ""
82        ## List of dispersed parameters
83        self.disp_params = []
84        #model description
85        self.description=''
86        # paramaters for fittable
87        self.fixed= []
88        # paramaters for non-fittable
89        self.non_fittable= []
90        ## parameters with orientation
91        self.orientation_params =[]
92        ## output directory for wrappers
93        self.output_dir = output_dir
94        self.c_wrapper_dir = c_wrapper_dir
95
96       
97       
98    def __repr__(self):
99        """ Simple output for printing """
100       
101        rep  = "\n Python class: %s\n\n" % self.pythonClass
102        rep += "  struc name: %s\n\n" % self.structName
103        rep += "  params:     %s\n\n" % self.params
104        rep += "  description:    %s\n\n" % self.description
105        rep += "  Fittable parameters:     %s\n\n"% self.fixed
106        rep += "  Non-Fittable parameters:     %s\n\n"% self.non_fittable
107        rep += "  Orientation parameters:  %s\n\n"% self.orientation_params
108        return rep
109       
110    def read(self):
111        """ Reads in the .h file to catch parameters of the wrapper """
112       
113        # Check if the file is there
114        if not os.path.isfile(self.file):
115            raise ValueError, "File %s is not a regular file" % self.file
116       
117        # Read file
118        f = open(self.file,'r')
119        buf = f.read()
120       
121        self.default_list = "List of default parameters:\n"
122        #lines = string.split(buf,'\n')
123        lines = buf.split('\n')
124        self.details  = "## Parameter details [units, min, max]\n"
125        self.details += "        self.details = {}\n"
126       
127        #open item in this case Fixed
128        text='text'
129        key2="<%s>"%text.lower()
130        # close an item in this case fixed
131        text='TexT'
132        key3="</%s>"%text.lower()
133       
134        ## Catch fixed parameters
135        key = "[FIXED]"
136        try:
137            self.fixed= lineparser.readhelper(lines, key, 
138                                              key2, key3, file=self.file)
139        except:
140           raise   
141        ## Catch non-fittable parameters parameters
142        key = "[NON_FITTABLE_PARAMS]"
143        try:
144            self.non_fittable= lineparser.readhelper(lines, key, key2,
145                                                     key3, file=self.file)
146        except:
147           raise   
148
149        ## Catch parameters with orientation
150        key = "[ORIENTATION_PARAMS]"   
151        try:
152            self.orientation_params = lineparser.readhelper(lines, key, 
153                                                    key2, key3, file=self.file)
154        except:
155           raise 
156        ## Catch Description
157        key = "[DESCRIPTION]"
158       
159        find_description = False
160        temp=""
161        for line in lines:
162            if line.count(key)>0 :
163               
164                try:
165                    find_description= True
166                    index = line.index(key)
167                    toks = line[index:].split("=",1 )
168                    temp=toks[1].lstrip().rstrip()
169                    text='text'
170                    key2="<%s>"%text.lower()
171                    if re.match(key2,temp)!=None:
172   
173                        toks2=temp.split(key2,1)
174                        self.description=toks2[1]
175                        text='text'
176                        key2="</%s>"%text.lower()
177                        if re.search(key2,toks2[1])!=None:
178                            temp=toks2[1].split(key2,1)
179                            self.description=temp[0]
180                            break
181                     
182                    else:
183                        self.description=temp
184                        break
185                except:
186                     raise ValueError, "Could not parse file %s" % self.file
187            elif find_description:
188                text='text'
189                key2="</%s>"%text.lower()
190                if re.search(key2,line)!=None:
191                    tok=line.split(key2,1)
192                    temp=tok[0].split("//",1)
193                    self.description+=tok[1].lstrip().rstrip()
194                    break
195                else:
196                    if re.search("//",line)!=None:
197                        temp=line.split("//",1)
198                        self.description+='\n\t\t'+temp[1].lstrip().rstrip()
199                       
200                    else:
201                        self.description+='\n\t\t'+line.lstrip().rstrip()
202                   
203               
204     
205        for line in lines:
206           
207            # Catch class name
208            key = "[PYTHONCLASS]"
209            if line.count(key)>0:
210                try:
211                    index = line.index(key)
212                    toks = line[index:].split("=" )
213                    self.pythonClass = toks[1].lstrip().rstrip()
214
215                except:
216                    raise ValueError, "Could not parse file %s" % self.file
217               
218            # Catch struct name
219            # C++ class definition
220            if line.count("class")>0:
221                # We are entering a class definition
222                self.inParDefs = True
223                self.foundCPP = True
224               
225            # Old-Style C struct definition
226            if line.count("typedef struct")>0:
227                # We are entering a struct block
228                self.inParDefs = True
229                self.inStruct = True
230           
231            if self.inParDefs and line.count("}")>0:
232                # We are exiting a struct block
233                self.inParDefs = False
234               
235                if self.inStruct:
236                    self.inStruct = False
237                    # Catch the name of the struct
238                    index = line.index("}")
239                    toks = line[index+1:].split(";")
240                    # Catch pointer definition
241                    toks2 = toks[0].split(',')
242                    self.structName = toks2[0].lstrip().rstrip()
243           
244            # Catch struct content
245            key = "[DEFAULT]"
246            if self.inParDefs and line.count(key)>0:
247                # Found a new parameter
248                try:
249                    index = line.index(key)
250                    toks = line[index:].split("=")
251                    toks2 = toks[2].split()
252                    val = float(toks2[0])
253                    self.params[toks[1]] = val
254                    #self.pythonClass = toks[1].lstrip().rstrip()
255                    units = ""
256                    if len(toks2) >= 2:
257                        units = toks2[1]
258                    self.default_list += "         %-15s = %s %s\n" % \
259                        (toks[1], val, units)
260                   
261                    # Check for min and max
262                    min = "None"
263                    max = "None"
264                    if len(toks2) == 4:
265                        min = toks2[2]
266                        max = toks2[3]
267                   
268                    self.details += "        self.details['%s'] = ['%s', %s, %s]\n" % \
269                        (toks[1].lstrip().rstrip(), units.lstrip().rstrip(), min, max)
270                except:
271                    raise ValueError, "Could not parse input file %s \n  %s" % \
272                        (self.file, sys.exc_value)
273               
274               
275            # Catch need for numerical calculations
276            key = "CalcParameters calcPars"
277            if line.count(key)>0:
278                self.modelCalcFlag = True
279               
280            # Catch list of dispersed parameters
281            key = "[DISP_PARAMS]"
282            if line.count(key)>0:
283                try:
284                    index = line.index(key)
285                    toks = line[index:].split("=")
286                    list_str = toks[1].lstrip().rstrip()
287                    self.disp_params = list_str.split(',')
288                except:
289                    raise ValueError, "Could not parse file %s" % self.file
290
291    def write_c_wrapper(self):
292        """ Writes the C file to create the python extension class
293            The file is written in C[PYTHONCLASS].c
294        """
295        file_path = os.path.join(self.c_wrapper_dir, "C"+self.pythonClass+'.cpp')
296        file = open(file_path, 'w')
297       
298        template = open(os.path.join(os.path.dirname(__file__), "classTemplate.txt"), 'r')
299       
300        tmp_buf = template.read()
301        #tmp_lines = string.split(tmp_buf,'\n')
302        tmp_lines = tmp_buf.split('\n')
303       
304        for tmp_line in tmp_lines:
305           
306            # Catch class name
307            newline = self.replaceToken(tmp_line, 
308                                        "[PYTHONCLASS]", 'C'+self.pythonClass)
309            #Catch model description
310            #newline = self.replaceToken(tmp_line,
311            #                            "[DESCRIPTION]", self.description)
312            # Catch C model name
313            newline = self.replaceToken(newline, 
314                                        "[CMODEL]", self.pythonClass)
315           
316            # Catch class name
317            newline = self.replaceToken(newline, 
318                                        "[MODELSTRUCT]", self.structName)
319           
320            # Dictionary initialization
321            param_str = "// Initialize parameter dictionary\n"           
322            for par in self.params:
323                param_str += "        PyDict_SetItemString(self->params,\"%s\",Py_BuildValue(\"d\",%10.12f));\n" % \
324                    (par, self.params[par])
325
326            if len(self.disp_params)>0:
327                param_str += "        // Initialize dispersion / averaging parameter dict\n"
328                param_str += "        DispersionVisitor* visitor = new DispersionVisitor();\n"
329                param_str += "        PyObject * disp_dict;\n"
330                for par in self.disp_params:
331                    par = par.strip()
332                    param_str += "        disp_dict = PyDict_New();\n"
333                    param_str += "        self->model->%s.dispersion->accept_as_source(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
334                    param_str += "        PyDict_SetItemString(self->dispersion, \"%s\", disp_dict);\n" % par
335               
336            # Initialize dispersion object dictionnary
337            param_str += "\n"
338           
339               
340            newline = self.replaceToken(newline,
341                                        "[INITDICTIONARY]", param_str)
342           
343            # Read dictionary
344            param_str = "    // Reader parameter dictionary\n"
345            for par in self.params:
346                param_str += "    self->model->%s = PyFloat_AsDouble( PyDict_GetItemString(self->params, \"%s\") );\n" % \
347                    (par, par)
348                   
349            if len(self.disp_params)>0:
350                param_str += "    // Read in dispersion parameters\n"
351                param_str += "    PyObject* disp_dict;\n"
352                param_str += "    DispersionVisitor* visitor = new DispersionVisitor();\n"
353                for par in self.disp_params:
354                    par = par.strip()
355                    param_str += "    disp_dict = PyDict_GetItemString(self->dispersion, \"%s\");\n" % par
356                    param_str += "    self->model->%s.dispersion->accept_as_destination(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
357               
358            newline = self.replaceToken(newline, "[READDICTIONARY]", param_str)
359               
360            # Name of .c file
361            #toks = string.split(self.file,'.')
362            basename = os.path.basename(self.file)
363            toks = basename.split('.')
364            newline = self.replaceToken(newline, "[C_FILENAME]", toks[0])
365           
366            # Include file
367            basename = os.path.basename(self.file)
368            newline = self.replaceToken(newline, 
369                                        "[INCLUDE_FILE]", self.file) 
370            if self.foundCPP:
371                newline = self.replaceToken(newline, 
372                                            "[C_INCLUDE_FILE]", "") 
373                newline = self.replaceToken(newline, 
374                                            "[CPP_INCLUDE_FILE]", "#include \"%s\"" % basename) 
375            else: 
376                newline = self.replaceToken(newline, 
377                                            "[C_INCLUDE_FILE]", "#include \"%s\"" % basename)   
378                newline = self.replaceToken(newline, 
379                                            "[CPP_INCLUDE_FILE]", "#include \"models.hh\"") 
380               
381            # Numerical calcs dealloc
382            dealloc_str = "\n"
383            if self.modelCalcFlag:
384                dealloc_str = "    modelcalculations_dealloc(&(self->model_pars.calcPars));\n"
385            newline = self.replaceToken(newline, 
386                                        "[NUMERICAL_DEALLOC]", dealloc_str)     
387               
388            # Numerical calcs init
389            init_str = "\n"
390            if self.modelCalcFlag:
391                init_str = "        modelcalculations_init(&(self->model_pars.calcPars));\n"
392            newline = self.replaceToken(newline, 
393                                        "[NUMERICAL_INIT]", init_str)     
394               
395            # Numerical calcs reset
396            reset_str = "\n"
397            if self.modelCalcFlag:
398                reset_str = "modelcalculations_reset(&(self->model_pars.calcPars));\n"
399            newline = self.replaceToken(newline, 
400                                        "[NUMERICAL_RESET]", reset_str)     
401               
402            # Setting dispsertion weights
403            set_weights = "    // Ugliness necessary to go from python to C\n"
404            set_weights = "    // TODO: refactor this\n"
405            for par in self.disp_params:
406                par = par.strip()
407                set_weights += "    if (!strcmp(par_name, \"%s\")) {\n" % par
408                set_weights += "        self->model->%s.dispersion = dispersion;\n" % par
409                set_weights += "    } else"
410            newline = self.replaceToken(newline, 
411                                        "[SET_DISPERSION]", set_weights)     
412           
413            # Write new line to the wrapper .c file
414            file.write(newline+'\n')
415           
416           
417        file.close()
418       
419    def write_python_wrapper(self):
420        """ Writes the python file to create the python extension class
421            The file is written in ../[PYTHONCLASS].py
422        """
423        file_path = os.path.join(self.output_dir, self.pythonClass+'.py')
424        file = open(file_path, 'w')
425        template = open(os.path.join(os.path.dirname(__file__), "modelTemplate.txt"), 'r')
426       
427        tmp_buf = template.read()
428        tmp_lines = tmp_buf.split('\n')
429       
430        for tmp_line in tmp_lines:
431           
432            # Catch class name
433            newline = self.replaceToken(tmp_line, 
434                                        "[CPYTHONCLASS]", 'C'+self.pythonClass)
435           
436            # Catch class name
437            newline = self.replaceToken(newline, 
438                                        "[PYTHONCLASS]", self.pythonClass)
439           
440            # Include file
441            newline = self.replaceToken(newline, 
442                                        "[INCLUDE_FILE]", self.file)   
443                   
444            # Include file
445            newline = self.replaceToken(newline, 
446                                        "[DEFAULT_LIST]", self.default_list)
447            # model description
448            newline = self.replaceToken(newline, 
449                                        "[DESCRIPTION]", self.description)
450            # Parameter details
451            newline = self.replaceToken(newline, 
452                                        "[PAR_DETAILS]", self.details)
453           
454            # fixed list  details
455            fixed_str = str(self.fixed)
456            fixed_str = fixed_str.replace(', ', ',\n                      ')
457            newline = self.replaceToken(newline, "[FIXED]",fixed_str)
458           
459            # non-fittable list details
460            pars_str = str(self.non_fittable)
461            pars_str = pars_str.replace(', ', 
462                                        ',\n                             ')
463            newline = self.replaceToken(newline, 
464                                        "[NON_FITTABLE_PARAMS]",
465                                        pars_str)
466           
467            ## parameters with orientation
468            oriented_str = str(self.orientation_params)
469            formatted_endl = ',\n                                   '
470            oriented_str = oriented_str.replace(', ', formatted_endl)
471            newline = self.replaceToken(newline, 
472                               "[ORIENTATION_PARAMS]", oriented_str)
473           
474            # Write new line to the wrapper .c file
475            file.write(newline+'\n')
476               
477        file.close()
478       
479       
480    def replaceToken(self, line, key, value): #pylint: disable-msg=R0201
481        """ Replace a token in the template file
482            @param line: line of text to inspect
483            @param key: token to look for
484            @param value: string value to replace the token with
485            @return: new string value
486        """
487        lenkey = len(key)
488        newline = line
489       
490        while newline.count(key)>0:
491            index = newline.index(key)
492            newline = newline[:index]+value+newline[index+lenkey:]
493       
494        return newline
495       
496    def getModelName(self):
497        return self.pythonClass
498       
499
500
501# main
502if __name__ == '__main__':
503    if len(sys.argv)>1:
504        print "Will look for file %s" % sys.argv[1]
505        app = WrapperGenerator(sys.argv[1])
506    else:
507        app = WrapperGenerator("test.h")
508    app.read()
509    app.write_c_wrapper()
510    app.write_python_wrapper()
511    print app
512   
513# End of file       
Note: See TracBrowser for help on using the repository browser.