source: sasview/sansmodels/src/python_wrapper/WrapperGenerator.py @ 98fdccd

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 98fdccd was 642a025, checked in by Mathieu Doucet <doucetm@…>, 13 years ago

remove files that are generated by build

  • Property mode set to 100644
File size: 20.0 KB
RevLine 
[af03ddd]1#!/usr/bin/env python
2""" WrapperGenerator class to generate model code automatically.
3"""
4
[95986b5]5import os, sys,re
[25a608f5]6import lineparser
[af03ddd]7
8class WrapperGenerator:
9    """ Python wrapper generator for C models
10   
11        The developer must provide a header file describing
12        the new model.
13       
14        To provide the name of the Python class to be
15        generated, the .h file must contain the following
16        string in the comments:
17       
18        // [PYTHONCLASS] = my_model
19       
20        where my_model must be replaced by the name of the
21        class that you want to import from sans.models.
22        (example: [PYTHONCLASS] = MyModel
23          will create a class MyModel in sans.models.MyModel.
24          It will also create a class CMyModel in
25          sans_extension.c_models.)
26         
27        Also in comments, each parameter of the params
28        dictionary must be declared with a default value
29        in the following way:
30       
31        // [DEFAULT]=param_name=default_value
32       
33        (example:
34            //  [DEFAULT]=radius=20.0
35        )
36         
37        See cylinder.h for an example.
38       
39       
40        A .c file corresponding to the .h file should also
41        be provided (example: my_model.h, my_model.c).
42   
43        The .h file should define two function definitions. For example,
44        cylinder.h defines the following:
45       
46            /// 1D scattering function
47            double cylinder_analytical_1D(CylinderParameters *pars, double q);
48           
49            /// 2D scattering function
50            double cylinder_analytical_2D(CylinderParameters *pars, double q, double phi);
51           
52        The .c file implements those functions.
53       
54        @author: Mathieu Doucet / UTK
55        @contact: mathieu.doucet@nist.gov
56    """
57   
[642a025]58    def __init__(self, filename, output_dir='.', c_wrapper_dir='.'):
[af03ddd]59        """ Initialization """
60       
61        ## Name of .h file to generate wrapper from
62        self.file = filename
63       
64        # Info read from file
65       
66        ## Name of python class to write
67        self.pythonClass = None
68        ## Parser in struct section
69        self.inStruct = False
[a1d1b90]70        self.foundCPP = False
[d62f422]71        self.inParDefs = False
[af03ddd]72        ## Name of struct for the c object
73        self.structName = None
74        ## Dictionary of parameters
75        self.params = {}
76        ## ModelCalculation module flag
77        self.modelCalcFlag = False
78        ## List of default parameters (text)
79        self.default_list = ""
80        ## Dictionary of units
81        self.details = ""
82        ## List of dispersed parameters
83        self.disp_params = []
[4e2f6ef8]84        #model description
85        self.description=''
[836fe6e]86        # paramaters for fittable
87        self.fixed= []
[35aface]88        # paramaters for non-fittable
89        self.non_fittable= []
[25a608f5]90        ## parameters with orientation
91        self.orientation_params =[]
[2d1b700]92        ## output directory for wrappers
93        self.output_dir = output_dir
[642a025]94        self.c_wrapper_dir = c_wrapper_dir
[25a608f5]95       
[af03ddd]96       
97    def __repr__(self):
98        """ Simple output for printing """
99       
[25a608f5]100        rep  = "\n Python class: %s\n\n" % self.pythonClass
101        rep += "  struc name: %s\n\n" % self.structName
102        rep += "  params:     %s\n\n" % self.params
103        rep += "  description:    %s\n\n" % self.description
104        rep += "  Fittable parameters:     %s\n\n"% self.fixed
[35aface]105        rep += "  Non-Fittable parameters:     %s\n\n"% self.non_fittable
[25a608f5]106        rep += "  Orientation parameters:  %s\n\n"% self.orientation_params
[af03ddd]107        return rep
108       
109    def read(self):
110        """ Reads in the .h file to catch parameters of the wrapper """
111       
112        # Check if the file is there
113        if not os.path.isfile(self.file):
114            raise ValueError, "File %s is not a regular file" % self.file
115       
116        # Read file
117        f = open(self.file,'r')
118        buf = f.read()
119       
120        self.default_list = "List of default parameters:\n"
121        #lines = string.split(buf,'\n')
122        lines = buf.split('\n')
123        self.details  = "## Parameter details [units, min, max]\n"
124        self.details += "        self.details = {}\n"
[25a608f5]125       
[836fe6e]126        #open item in this case Fixed
127        text='text'
128        key2="<%s>"%text.lower()
129        # close an item in this case fixed
[da3dae3]130        text='TexT'
[836fe6e]131        key3="</%s>"%text.lower()
132       
[25a608f5]133        ## Catch fixed parameters
134        key = "[FIXED]"
135        try:
136            self.fixed= lineparser.readhelper(lines,key, key2,key3, file= self.file)
137        except:
138           raise   
[35aface]139        ## Catch non-fittable parameters parameters
140        key = "[NON_FITTABLE_PARAMS]"
141        try:
142            self.non_fittable= lineparser.readhelper(lines,key, key2,key3, file= self.file)
143        except:
144           raise   
145
[25a608f5]146        ## Catch parameters with orientation
147        key = "[ORIENTATION_PARAMS]"   
148        try:
149            self.orientation_params = lineparser.readhelper( lines,key, 
150                                                    key2,key3, file= self.file)
151        except:
152           raise 
153        ## Catch Description
[96672c0]154        key = "[DESCRIPTION]"
[25a608f5]155       
156        find_description = False
[96672c0]157        temp=""
158        for line in lines:
159            if line.count(key)>0 :
[9316609]160               
[96672c0]161                try:
[25a608f5]162                    find_description= True
[96672c0]163                    index = line.index(key)
164                    toks = line[index:].split("=",1 )
165                    temp=toks[1].lstrip().rstrip()
166                    text='text'
[9316609]167                    key2="<%s>"%text.lower()
[96672c0]168                    if re.match(key2,temp)!=None:
[25a608f5]169   
[9316609]170                        toks2=temp.split(key2,1)
171                        self.description=toks2[1]
172                        text='text'
173                        key2="</%s>"%text.lower()
174                        if re.search(key2,toks2[1])!=None:
175                            temp=toks2[1].split(key2,1)
176                            self.description=temp[0]
177                            break
[25a608f5]178                     
[96672c0]179                    else:
180                        self.description=temp
[9316609]181                        break
[96672c0]182                except:
183                     raise ValueError, "Could not parse file %s" % self.file
[25a608f5]184            elif find_description:
[9316609]185                text='text'
186                key2="</%s>"%text.lower()
187                if re.search(key2,line)!=None:
188                    tok=line.split(key2,1)
189                    temp=tok[0].split("//",1)
190                    self.description+=tok[1].lstrip().rstrip()
191                    break
192                else:
193                    if re.search("//",line)!=None:
194                        temp=line.split("//",1)
195                        self.description+='\n\t\t'+temp[1].lstrip().rstrip()
196                       
197                    else:
198                        self.description+='\n\t\t'+line.lstrip().rstrip()
199                   
200               
[96672c0]201               
[af03ddd]202        for line in lines:
203           
204            # Catch class name
205            key = "[PYTHONCLASS]"
206            if line.count(key)>0:
207                try:
208                    index = line.index(key)
209                    #toks = string.split( line[index:], "=" )
210                    toks = line[index:].split("=" )
211                    self.pythonClass = toks[1].lstrip().rstrip()
212                except:
213                    raise ValueError, "Could not parse file %s" % self.file
214               
215            # Catch struct name
[a1d1b90]216            # C++ class definition
217            if line.count("class")>0:
218                # We are entering a class definition
[d62f422]219                self.inParDefs = True
[a1d1b90]220                self.foundCPP = True
221               
222            # Old-Style C struct definition
[af03ddd]223            if line.count("typedef struct")>0:
224                # We are entering a struct block
[d62f422]225                self.inParDefs = True
[af03ddd]226                self.inStruct = True
227           
[d62f422]228            if self.inParDefs and line.count("}")>0:
[af03ddd]229                # We are exiting a struct block
[d62f422]230                self.inParDefs = False
[af03ddd]231               
[d62f422]232                if self.inStruct:
233                    self.inStruct = False
234                    # Catch the name of the struct
235                    index = line.index("}")
236                    toks = line[index+1:].split(";")
237                    # Catch pointer definition
238                    toks2 = toks[0].split(',')
239                    self.structName = toks2[0].lstrip().rstrip()
[96672c0]240           
[af03ddd]241            # Catch struct content
242            key = "[DEFAULT]"
[d62f422]243            if self.inParDefs and line.count(key)>0:
[af03ddd]244                # Found a new parameter
245                try:
246                    index = line.index(key)
247                    toks = line[index:].split("=")
248                    toks2 = toks[2].split()
249                    val = float(toks2[0])
250                    self.params[toks[1]] = val
251                    #self.pythonClass = toks[1].lstrip().rstrip()
252                    units = ""
253                    if len(toks2) >= 2:
254                        units = toks2[1]
255                    self.default_list += "         %-15s = %s %s\n" % \
256                        (toks[1], val, units)
257                   
258                    # Check for min and max
259                    min = "None"
260                    max = "None"
261                    if len(toks2) == 4:
262                        min = toks2[2]
263                        max = toks2[3]
264                   
265                    self.details += "        self.details['%s'] = ['%s', %s, %s]\n" % \
266                        (toks[1].lstrip().rstrip(), units.lstrip().rstrip(), min, max)
267                except:
268                    raise ValueError, "Could not parse input file %s \n  %s" % \
269                        (self.file, sys.exc_value)
270               
271               
272            # Catch need for numerical calculations
273            key = "CalcParameters calcPars"
274            if line.count(key)>0:
275                self.modelCalcFlag = True
276               
277            # Catch list of dispersed parameters
278            key = "[DISP_PARAMS]"
279            if line.count(key)>0:
280                try:
281                    index = line.index(key)
282                    toks = line[index:].split("=")
283                    list_str = toks[1].lstrip().rstrip()
284                    self.disp_params = list_str.split(',')
285                except:
286                    raise ValueError, "Could not parse file %s" % self.file
287               
[95986b5]288       
[af03ddd]289               
290    def write_c_wrapper(self):
291        """ Writes the C file to create the python extension class
292            The file is written in C[PYTHONCLASS].c
293        """
[642a025]294        file_path = os.path.join(self.c_wrapper_dir, "C"+self.pythonClass+'.cpp')
[2d1b700]295        file = open(file_path, 'w')
[af03ddd]296       
[2d1b700]297        template = open(os.path.join(os.path.dirname(__file__), "classTemplate.txt"), 'r')
[af03ddd]298       
299        tmp_buf = template.read()
300        #tmp_lines = string.split(tmp_buf,'\n')
301        tmp_lines = tmp_buf.split('\n')
302       
303        for tmp_line in tmp_lines:
304           
305            # Catch class name
306            newline = self.replaceToken(tmp_line, 
307                                        "[PYTHONCLASS]", 'C'+self.pythonClass)
[9316609]308            #Catch model description
[95986b5]309            #newline = self.replaceToken(tmp_line,
310            #                            "[DESCRIPTION]", self.description)
[af03ddd]311            # Catch C model name
312            newline = self.replaceToken(newline, 
313                                        "[CMODEL]", self.pythonClass)
314           
315            # Catch class name
316            newline = self.replaceToken(newline, 
317                                        "[MODELSTRUCT]", self.structName)
318           
319            # Dictionary initialization
320            param_str = "// Initialize parameter dictionary\n"           
321            for par in self.params:
[35aface]322                param_str += "        PyDict_SetItemString(self->params,\"%s\",Py_BuildValue(\"d\",%10.12f));\n" % \
[0f5bc9f]323                    (par, self.params[par])
[af03ddd]324
325            param_str += "        // Initialize dispersion / averaging parameter dict\n"
326            param_str += "        DispersionVisitor* visitor = new DispersionVisitor();\n"
327            param_str += "        PyObject * disp_dict;\n"
328            for par in self.disp_params:
329                par = par.strip()
330                param_str += "        disp_dict = PyDict_New();\n"
331                param_str += "        self->model->%s.dispersion->accept_as_source(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
332                param_str += "        PyDict_SetItemString(self->dispersion, \"%s\", disp_dict);\n" % par
333               
334            # Initialize dispersion object dictionnary
335            param_str += "\n"
336           
337               
338            newline = self.replaceToken(newline,
339                                        "[INITDICTIONARY]", param_str)
340           
341            # Read dictionary
342            param_str = "    // Reader parameter dictionary\n"
343            for par in self.params:
344                param_str += "    self->model->%s = PyFloat_AsDouble( PyDict_GetItemString(self->params, \"%s\") );\n" % \
345                    (par, par)
346                   
347            param_str += "    // Read in dispersion parameters\n"
348            param_str += "    PyObject* disp_dict;\n"
349            param_str += "    DispersionVisitor* visitor = new DispersionVisitor();\n"
350            for par in self.disp_params:
351                par = par.strip()
352                param_str += "    disp_dict = PyDict_GetItemString(self->dispersion, \"%s\");\n" % par
353                param_str += "    self->model->%s.dispersion->accept_as_destination(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
354               
355            newline = self.replaceToken(newline, "[READDICTIONARY]", param_str)
356               
357            # Name of .c file
358            #toks = string.split(self.file,'.')
359            basename = os.path.basename(self.file)
360            toks = basename.split('.')
361            newline = self.replaceToken(newline, "[C_FILENAME]", toks[0])
362           
363            # Include file
364            basename = os.path.basename(self.file)
365            newline = self.replaceToken(newline, 
[a1d1b90]366                                        "[INCLUDE_FILE]", self.file) 
367            if self.foundCPP:
368                newline = self.replaceToken(newline, 
369                                            "[C_INCLUDE_FILE]", "") 
370                newline = self.replaceToken(newline, 
371                                            "[CPP_INCLUDE_FILE]", "#include \"%s\"" % basename) 
372            else: 
373                newline = self.replaceToken(newline, 
374                                            "[C_INCLUDE_FILE]", "#include \"%s\"" % basename)   
375                newline = self.replaceToken(newline, 
376                                            "[CPP_INCLUDE_FILE]", "#include \"models.hh\"") 
[af03ddd]377               
378            # Numerical calcs dealloc
379            dealloc_str = "\n"
380            if self.modelCalcFlag:
381                dealloc_str = "    modelcalculations_dealloc(&(self->model_pars.calcPars));\n"
382            newline = self.replaceToken(newline, 
383                                        "[NUMERICAL_DEALLOC]", dealloc_str)     
384               
385            # Numerical calcs init
386            init_str = "\n"
387            if self.modelCalcFlag:
388                init_str = "        modelcalculations_init(&(self->model_pars.calcPars));\n"
389            newline = self.replaceToken(newline, 
390                                        "[NUMERICAL_INIT]", init_str)     
391               
392            # Numerical calcs reset
393            reset_str = "\n"
394            if self.modelCalcFlag:
395                reset_str = "modelcalculations_reset(&(self->model_pars.calcPars));\n"
396            newline = self.replaceToken(newline, 
397                                        "[NUMERICAL_RESET]", reset_str)     
398               
399            # Setting dispsertion weights
400            set_weights = "    // Ugliness necessary to go from python to C\n"
401            set_weights = "    // TODO: refactor this\n"
402            for par in self.disp_params:
403                par = par.strip()
404                set_weights += "    if (!strcmp(par_name, \"%s\")) {\n" % par
405                set_weights += "        self->model->%s.dispersion = dispersion;\n" % par
406                set_weights += "    } else"
407            newline = self.replaceToken(newline, 
408                                        "[SET_DISPERSION]", set_weights)     
409           
410            # Write new line to the wrapper .c file
411            file.write(newline+'\n')
412           
413           
414        file.close()
415       
416    def write_python_wrapper(self):
417        """ Writes the python file to create the python extension class
418            The file is written in ../[PYTHONCLASS].py
419        """
[642a025]420        file_path = os.path.join(self.output_dir, self.pythonClass+'.py')
[2d1b700]421        file = open(file_path, 'w')
422        template = open(os.path.join(os.path.dirname(__file__), "modelTemplate.txt"), 'r')
[af03ddd]423       
424        tmp_buf = template.read()
425        tmp_lines = tmp_buf.split('\n')
426       
427        for tmp_line in tmp_lines:
428           
429            # Catch class name
430            newline = self.replaceToken(tmp_line, 
431                                        "[CPYTHONCLASS]", 'C'+self.pythonClass)
432           
433            # Catch class name
434            newline = self.replaceToken(newline, 
435                                        "[PYTHONCLASS]", self.pythonClass)
436           
437            # Include file
438            newline = self.replaceToken(newline, 
439                                        "[INCLUDE_FILE]", self.file)   
440                   
441            # Include file
442            newline = self.replaceToken(newline, 
443                                        "[DEFAULT_LIST]", self.default_list)
[95986b5]444            # model description
[af03ddd]445            newline = self.replaceToken(newline, 
[95986b5]446                                        "[DESCRIPTION]", self.description)
[4e2f6ef8]447            # Parameter details
448            newline = self.replaceToken(newline, 
[95986b5]449                                        "[PAR_DETAILS]", self.details)
450           
[836fe6e]451            # fixed list  details
452            newline = self.replaceToken(newline, 
453                                        "[FIXED]",str(self.fixed))
[35aface]454            # non-fittable list  details
455            newline = self.replaceToken(newline, 
456                                        "[NON_FITTABLE_PARAMS]",str(self.non_fittable))
[25a608f5]457            ## parameters with orientation
458       
459            newline = self.replaceToken(newline, 
460                               "[ORIENTATION_PARAMS]",str(self.orientation_params))
[836fe6e]461           
[af03ddd]462            # Write new line to the wrapper .c file
463            file.write(newline+'\n')
464               
465        file.close()
466       
467       
468    def replaceToken(self, line, key, value): #pylint: disable-msg=R0201
469        """ Replace a token in the template file
470            @param line: line of text to inspect
471            @param key: token to look for
472            @param value: string value to replace the token with
473            @return: new string value
474        """
475        lenkey = len(key)
476        newline = line
[836fe6e]477       
[af03ddd]478        while newline.count(key)>0:
479            index = newline.index(key)
480            newline = newline[:index]+value+newline[index+lenkey:]
[836fe6e]481       
[af03ddd]482        return newline
483       
484       
485# main
486if __name__ == '__main__':
487    if len(sys.argv)>1:
488        print "Will look for file %s" % sys.argv[1]
[95986b5]489    #app = WrapperGenerator('../c_extensions/elliptical_cylinder.h')
[af03ddd]490        app = WrapperGenerator(sys.argv[1])
491    else:
492        app = WrapperGenerator("test.h")
493    app.read()
494    app.write_c_wrapper()
495    app.write_python_wrapper()
496    print app
497   
498# End of file       
Note: See TracBrowser for help on using the repository browser.