source: sasview/sansmodels/src/python_wrapper/WrapperGenerator.py @ dac6869

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since dac6869 was 1b758b3, checked in by Mathieu Doucet <doucetm@…>, 13 years ago

clean up warnings

  • Property mode set to 100644
File size: 20.0 KB
RevLine 
[af03ddd]1#!/usr/bin/env python
2""" WrapperGenerator class to generate model code automatically.
3"""
4
[95986b5]5import os, sys,re
[25a608f5]6import lineparser
[af03ddd]7
8class WrapperGenerator:
9    """ Python wrapper generator for C models
10   
11        The developer must provide a header file describing
12        the new model.
13       
14        To provide the name of the Python class to be
15        generated, the .h file must contain the following
16        string in the comments:
17       
18        // [PYTHONCLASS] = my_model
19       
20        where my_model must be replaced by the name of the
21        class that you want to import from sans.models.
22        (example: [PYTHONCLASS] = MyModel
23          will create a class MyModel in sans.models.MyModel.
24          It will also create a class CMyModel in
25          sans_extension.c_models.)
26         
27        Also in comments, each parameter of the params
28        dictionary must be declared with a default value
29        in the following way:
30       
31        // [DEFAULT]=param_name=default_value
32       
33        (example:
34            //  [DEFAULT]=radius=20.0
35        )
36         
37        See cylinder.h for an example.
38       
39       
40        A .c file corresponding to the .h file should also
41        be provided (example: my_model.h, my_model.c).
42   
43        The .h file should define two function definitions. For example,
44        cylinder.h defines the following:
45       
46            /// 1D scattering function
47            double cylinder_analytical_1D(CylinderParameters *pars, double q);
48           
49            /// 2D scattering function
50            double cylinder_analytical_2D(CylinderParameters *pars, double q, double phi);
51           
52        The .c file implements those functions.
53       
54        @author: Mathieu Doucet / UTK
55        @contact: mathieu.doucet@nist.gov
56    """
57   
[642a025]58    def __init__(self, filename, output_dir='.', c_wrapper_dir='.'):
[af03ddd]59        """ Initialization """
60       
61        ## Name of .h file to generate wrapper from
62        self.file = filename
63       
64        # Info read from file
65       
66        ## Name of python class to write
67        self.pythonClass = None
68        ## Parser in struct section
69        self.inStruct = False
[a1d1b90]70        self.foundCPP = False
[d62f422]71        self.inParDefs = False
[af03ddd]72        ## Name of struct for the c object
73        self.structName = None
74        ## Dictionary of parameters
75        self.params = {}
76        ## ModelCalculation module flag
77        self.modelCalcFlag = False
78        ## List of default parameters (text)
79        self.default_list = ""
80        ## Dictionary of units
81        self.details = ""
82        ## List of dispersed parameters
83        self.disp_params = []
[4e2f6ef8]84        #model description
85        self.description=''
[836fe6e]86        # paramaters for fittable
87        self.fixed= []
[35aface]88        # paramaters for non-fittable
89        self.non_fittable= []
[25a608f5]90        ## parameters with orientation
91        self.orientation_params =[]
[2d1b700]92        ## output directory for wrappers
93        self.output_dir = output_dir
[642a025]94        self.c_wrapper_dir = c_wrapper_dir
[25a608f5]95       
[af03ddd]96       
97    def __repr__(self):
98        """ Simple output for printing """
99       
[25a608f5]100        rep  = "\n Python class: %s\n\n" % self.pythonClass
101        rep += "  struc name: %s\n\n" % self.structName
102        rep += "  params:     %s\n\n" % self.params
103        rep += "  description:    %s\n\n" % self.description
104        rep += "  Fittable parameters:     %s\n\n"% self.fixed
[35aface]105        rep += "  Non-Fittable parameters:     %s\n\n"% self.non_fittable
[25a608f5]106        rep += "  Orientation parameters:  %s\n\n"% self.orientation_params
[af03ddd]107        return rep
108       
109    def read(self):
110        """ Reads in the .h file to catch parameters of the wrapper """
111       
112        # Check if the file is there
113        if not os.path.isfile(self.file):
114            raise ValueError, "File %s is not a regular file" % self.file
115       
116        # Read file
117        f = open(self.file,'r')
118        buf = f.read()
119       
120        self.default_list = "List of default parameters:\n"
121        #lines = string.split(buf,'\n')
122        lines = buf.split('\n')
123        self.details  = "## Parameter details [units, min, max]\n"
124        self.details += "        self.details = {}\n"
[25a608f5]125       
[836fe6e]126        #open item in this case Fixed
127        text='text'
128        key2="<%s>"%text.lower()
129        # close an item in this case fixed
[da3dae3]130        text='TexT'
[836fe6e]131        key3="</%s>"%text.lower()
132       
[25a608f5]133        ## Catch fixed parameters
134        key = "[FIXED]"
135        try:
136            self.fixed= lineparser.readhelper(lines,key, key2,key3, file= self.file)
137        except:
138           raise   
[35aface]139        ## Catch non-fittable parameters parameters
140        key = "[NON_FITTABLE_PARAMS]"
141        try:
142            self.non_fittable= lineparser.readhelper(lines,key, key2,key3, file= self.file)
143        except:
144           raise   
145
[25a608f5]146        ## Catch parameters with orientation
147        key = "[ORIENTATION_PARAMS]"   
148        try:
149            self.orientation_params = lineparser.readhelper( lines,key, 
150                                                    key2,key3, file= self.file)
151        except:
152           raise 
153        ## Catch Description
[96672c0]154        key = "[DESCRIPTION]"
[25a608f5]155       
156        find_description = False
[96672c0]157        temp=""
158        for line in lines:
159            if line.count(key)>0 :
[9316609]160               
[96672c0]161                try:
[25a608f5]162                    find_description= True
[96672c0]163                    index = line.index(key)
164                    toks = line[index:].split("=",1 )
165                    temp=toks[1].lstrip().rstrip()
166                    text='text'
[9316609]167                    key2="<%s>"%text.lower()
[96672c0]168                    if re.match(key2,temp)!=None:
[25a608f5]169   
[9316609]170                        toks2=temp.split(key2,1)
171                        self.description=toks2[1]
172                        text='text'
173                        key2="</%s>"%text.lower()
174                        if re.search(key2,toks2[1])!=None:
175                            temp=toks2[1].split(key2,1)
176                            self.description=temp[0]
177                            break
[25a608f5]178                     
[96672c0]179                    else:
180                        self.description=temp
[9316609]181                        break
[96672c0]182                except:
183                     raise ValueError, "Could not parse file %s" % self.file
[25a608f5]184            elif find_description:
[9316609]185                text='text'
186                key2="</%s>"%text.lower()
187                if re.search(key2,line)!=None:
188                    tok=line.split(key2,1)
189                    temp=tok[0].split("//",1)
190                    self.description+=tok[1].lstrip().rstrip()
191                    break
192                else:
193                    if re.search("//",line)!=None:
194                        temp=line.split("//",1)
195                        self.description+='\n\t\t'+temp[1].lstrip().rstrip()
196                       
197                    else:
198                        self.description+='\n\t\t'+line.lstrip().rstrip()
199                   
200               
[96672c0]201               
[af03ddd]202        for line in lines:
203           
204            # Catch class name
205            key = "[PYTHONCLASS]"
206            if line.count(key)>0:
207                try:
208                    index = line.index(key)
209                    #toks = string.split( line[index:], "=" )
210                    toks = line[index:].split("=" )
211                    self.pythonClass = toks[1].lstrip().rstrip()
212                except:
213                    raise ValueError, "Could not parse file %s" % self.file
214               
215            # Catch struct name
[a1d1b90]216            # C++ class definition
217            if line.count("class")>0:
218                # We are entering a class definition
[d62f422]219                self.inParDefs = True
[a1d1b90]220                self.foundCPP = True
221               
222            # Old-Style C struct definition
[af03ddd]223            if line.count("typedef struct")>0:
224                # We are entering a struct block
[d62f422]225                self.inParDefs = True
[af03ddd]226                self.inStruct = True
227           
[d62f422]228            if self.inParDefs and line.count("}")>0:
[af03ddd]229                # We are exiting a struct block
[d62f422]230                self.inParDefs = False
[af03ddd]231               
[d62f422]232                if self.inStruct:
233                    self.inStruct = False
234                    # Catch the name of the struct
235                    index = line.index("}")
236                    toks = line[index+1:].split(";")
237                    # Catch pointer definition
238                    toks2 = toks[0].split(',')
239                    self.structName = toks2[0].lstrip().rstrip()
[96672c0]240           
[af03ddd]241            # Catch struct content
242            key = "[DEFAULT]"
[d62f422]243            if self.inParDefs and line.count(key)>0:
[af03ddd]244                # Found a new parameter
245                try:
246                    index = line.index(key)
247                    toks = line[index:].split("=")
248                    toks2 = toks[2].split()
249                    val = float(toks2[0])
250                    self.params[toks[1]] = val
251                    #self.pythonClass = toks[1].lstrip().rstrip()
252                    units = ""
253                    if len(toks2) >= 2:
254                        units = toks2[1]
255                    self.default_list += "         %-15s = %s %s\n" % \
256                        (toks[1], val, units)
257                   
258                    # Check for min and max
259                    min = "None"
260                    max = "None"
261                    if len(toks2) == 4:
262                        min = toks2[2]
263                        max = toks2[3]
264                   
265                    self.details += "        self.details['%s'] = ['%s', %s, %s]\n" % \
266                        (toks[1].lstrip().rstrip(), units.lstrip().rstrip(), min, max)
267                except:
268                    raise ValueError, "Could not parse input file %s \n  %s" % \
269                        (self.file, sys.exc_value)
270               
271               
272            # Catch need for numerical calculations
273            key = "CalcParameters calcPars"
274            if line.count(key)>0:
275                self.modelCalcFlag = True
276               
277            # Catch list of dispersed parameters
278            key = "[DISP_PARAMS]"
279            if line.count(key)>0:
280                try:
281                    index = line.index(key)
282                    toks = line[index:].split("=")
283                    list_str = toks[1].lstrip().rstrip()
284                    self.disp_params = list_str.split(',')
285                except:
286                    raise ValueError, "Could not parse file %s" % self.file
287               
[95986b5]288       
[af03ddd]289               
290    def write_c_wrapper(self):
291        """ Writes the C file to create the python extension class
292            The file is written in C[PYTHONCLASS].c
293        """
[642a025]294        file_path = os.path.join(self.c_wrapper_dir, "C"+self.pythonClass+'.cpp')
[2d1b700]295        file = open(file_path, 'w')
[af03ddd]296       
[2d1b700]297        template = open(os.path.join(os.path.dirname(__file__), "classTemplate.txt"), 'r')
[af03ddd]298       
299        tmp_buf = template.read()
300        #tmp_lines = string.split(tmp_buf,'\n')
301        tmp_lines = tmp_buf.split('\n')
302       
303        for tmp_line in tmp_lines:
304           
305            # Catch class name
306            newline = self.replaceToken(tmp_line, 
307                                        "[PYTHONCLASS]", 'C'+self.pythonClass)
[9316609]308            #Catch model description
[95986b5]309            #newline = self.replaceToken(tmp_line,
310            #                            "[DESCRIPTION]", self.description)
[af03ddd]311            # Catch C model name
312            newline = self.replaceToken(newline, 
313                                        "[CMODEL]", self.pythonClass)
314           
315            # Catch class name
316            newline = self.replaceToken(newline, 
317                                        "[MODELSTRUCT]", self.structName)
318           
319            # Dictionary initialization
320            param_str = "// Initialize parameter dictionary\n"           
321            for par in self.params:
[35aface]322                param_str += "        PyDict_SetItemString(self->params,\"%s\",Py_BuildValue(\"d\",%10.12f));\n" % \
[0f5bc9f]323                    (par, self.params[par])
[af03ddd]324
[1b758b3]325            if len(self.disp_params)>0:
326                param_str += "        // Initialize dispersion / averaging parameter dict\n"
327                param_str += "        DispersionVisitor* visitor = new DispersionVisitor();\n"
328                param_str += "        PyObject * disp_dict;\n"
329                for par in self.disp_params:
330                    par = par.strip()
331                    param_str += "        disp_dict = PyDict_New();\n"
332                    param_str += "        self->model->%s.dispersion->accept_as_source(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
333                    param_str += "        PyDict_SetItemString(self->dispersion, \"%s\", disp_dict);\n" % par
[af03ddd]334               
335            # Initialize dispersion object dictionnary
336            param_str += "\n"
337           
338               
339            newline = self.replaceToken(newline,
340                                        "[INITDICTIONARY]", param_str)
341           
342            # Read dictionary
343            param_str = "    // Reader parameter dictionary\n"
344            for par in self.params:
345                param_str += "    self->model->%s = PyFloat_AsDouble( PyDict_GetItemString(self->params, \"%s\") );\n" % \
346                    (par, par)
347                   
[1b758b3]348            if len(self.disp_params)>0:
349                param_str += "    // Read in dispersion parameters\n"
350                param_str += "    PyObject* disp_dict;\n"
351                param_str += "    DispersionVisitor* visitor = new DispersionVisitor();\n"
352                for par in self.disp_params:
353                    par = par.strip()
354                    param_str += "    disp_dict = PyDict_GetItemString(self->dispersion, \"%s\");\n" % par
355                    param_str += "    self->model->%s.dispersion->accept_as_destination(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
[af03ddd]356               
357            newline = self.replaceToken(newline, "[READDICTIONARY]", param_str)
358               
359            # Name of .c file
360            #toks = string.split(self.file,'.')
361            basename = os.path.basename(self.file)
362            toks = basename.split('.')
363            newline = self.replaceToken(newline, "[C_FILENAME]", toks[0])
364           
365            # Include file
366            basename = os.path.basename(self.file)
367            newline = self.replaceToken(newline, 
[a1d1b90]368                                        "[INCLUDE_FILE]", self.file) 
369            if self.foundCPP:
370                newline = self.replaceToken(newline, 
371                                            "[C_INCLUDE_FILE]", "") 
372                newline = self.replaceToken(newline, 
373                                            "[CPP_INCLUDE_FILE]", "#include \"%s\"" % basename) 
374            else: 
375                newline = self.replaceToken(newline, 
376                                            "[C_INCLUDE_FILE]", "#include \"%s\"" % basename)   
377                newline = self.replaceToken(newline, 
378                                            "[CPP_INCLUDE_FILE]", "#include \"models.hh\"") 
[af03ddd]379               
380            # Numerical calcs dealloc
381            dealloc_str = "\n"
382            if self.modelCalcFlag:
383                dealloc_str = "    modelcalculations_dealloc(&(self->model_pars.calcPars));\n"
384            newline = self.replaceToken(newline, 
385                                        "[NUMERICAL_DEALLOC]", dealloc_str)     
386               
387            # Numerical calcs init
388            init_str = "\n"
389            if self.modelCalcFlag:
390                init_str = "        modelcalculations_init(&(self->model_pars.calcPars));\n"
391            newline = self.replaceToken(newline, 
392                                        "[NUMERICAL_INIT]", init_str)     
393               
394            # Numerical calcs reset
395            reset_str = "\n"
396            if self.modelCalcFlag:
397                reset_str = "modelcalculations_reset(&(self->model_pars.calcPars));\n"
398            newline = self.replaceToken(newline, 
399                                        "[NUMERICAL_RESET]", reset_str)     
400               
401            # Setting dispsertion weights
402            set_weights = "    // Ugliness necessary to go from python to C\n"
403            set_weights = "    // TODO: refactor this\n"
404            for par in self.disp_params:
405                par = par.strip()
406                set_weights += "    if (!strcmp(par_name, \"%s\")) {\n" % par
407                set_weights += "        self->model->%s.dispersion = dispersion;\n" % par
408                set_weights += "    } else"
409            newline = self.replaceToken(newline, 
410                                        "[SET_DISPERSION]", set_weights)     
411           
412            # Write new line to the wrapper .c file
413            file.write(newline+'\n')
414           
415           
416        file.close()
417       
418    def write_python_wrapper(self):
419        """ Writes the python file to create the python extension class
420            The file is written in ../[PYTHONCLASS].py
421        """
[642a025]422        file_path = os.path.join(self.output_dir, self.pythonClass+'.py')
[2d1b700]423        file = open(file_path, 'w')
424        template = open(os.path.join(os.path.dirname(__file__), "modelTemplate.txt"), 'r')
[af03ddd]425       
426        tmp_buf = template.read()
427        tmp_lines = tmp_buf.split('\n')
428       
429        for tmp_line in tmp_lines:
430           
431            # Catch class name
432            newline = self.replaceToken(tmp_line, 
433                                        "[CPYTHONCLASS]", 'C'+self.pythonClass)
434           
435            # Catch class name
436            newline = self.replaceToken(newline, 
437                                        "[PYTHONCLASS]", self.pythonClass)
438           
439            # Include file
440            newline = self.replaceToken(newline, 
441                                        "[INCLUDE_FILE]", self.file)   
442                   
443            # Include file
444            newline = self.replaceToken(newline, 
445                                        "[DEFAULT_LIST]", self.default_list)
[95986b5]446            # model description
[af03ddd]447            newline = self.replaceToken(newline, 
[95986b5]448                                        "[DESCRIPTION]", self.description)
[4e2f6ef8]449            # Parameter details
450            newline = self.replaceToken(newline, 
[95986b5]451                                        "[PAR_DETAILS]", self.details)
452           
[836fe6e]453            # fixed list  details
454            newline = self.replaceToken(newline, 
455                                        "[FIXED]",str(self.fixed))
[35aface]456            # non-fittable list  details
457            newline = self.replaceToken(newline, 
458                                        "[NON_FITTABLE_PARAMS]",str(self.non_fittable))
[25a608f5]459            ## parameters with orientation
460       
461            newline = self.replaceToken(newline, 
462                               "[ORIENTATION_PARAMS]",str(self.orientation_params))
[836fe6e]463           
[af03ddd]464            # Write new line to the wrapper .c file
465            file.write(newline+'\n')
466               
467        file.close()
468       
469       
470    def replaceToken(self, line, key, value): #pylint: disable-msg=R0201
471        """ Replace a token in the template file
472            @param line: line of text to inspect
473            @param key: token to look for
474            @param value: string value to replace the token with
475            @return: new string value
476        """
477        lenkey = len(key)
478        newline = line
[836fe6e]479       
[af03ddd]480        while newline.count(key)>0:
481            index = newline.index(key)
482            newline = newline[:index]+value+newline[index+lenkey:]
[836fe6e]483       
[af03ddd]484        return newline
485       
486       
487# main
488if __name__ == '__main__':
489    if len(sys.argv)>1:
490        print "Will look for file %s" % sys.argv[1]
491        app = WrapperGenerator(sys.argv[1])
492    else:
493        app = WrapperGenerator("test.h")
494    app.read()
495    app.write_c_wrapper()
496    app.write_python_wrapper()
497    print app
498   
499# End of file       
Note: See TracBrowser for help on using the repository browser.