source: sasview/sansmodels/src/sans/models/c_models/WrapperGenerator.py @ c09ac449

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since c09ac449 was 25a608f5, checked in by Gervaise Alina <gervyh@…>, 16 years ago

add list of orientation _parameters to models

  • Property mode set to 100644
File size: 18.1 KB
Line 
1#!/usr/bin/env python
2""" WrapperGenerator class to generate model code automatically.
3"""
4
5import os, sys,re
6import lineparser
7
8class WrapperGenerator:
9    """ Python wrapper generator for C models
10   
11        The developer must provide a header file describing
12        the new model.
13       
14        To provide the name of the Python class to be
15        generated, the .h file must contain the following
16        string in the comments:
17       
18        // [PYTHONCLASS] = my_model
19       
20        where my_model must be replaced by the name of the
21        class that you want to import from sans.models.
22        (example: [PYTHONCLASS] = MyModel
23          will create a class MyModel in sans.models.MyModel.
24          It will also create a class CMyModel in
25          sans_extension.c_models.)
26         
27        Also in comments, each parameter of the params
28        dictionary must be declared with a default value
29        in the following way:
30       
31        // [DEFAULT]=param_name=default_value
32       
33        (example:
34            //  [DEFAULT]=radius=20.0
35        )
36         
37        See cylinder.h for an example.
38       
39       
40        A .c file corresponding to the .h file should also
41        be provided (example: my_model.h, my_model.c).
42   
43        The .h file should define two function definitions. For example,
44        cylinder.h defines the following:
45       
46            /// 1D scattering function
47            double cylinder_analytical_1D(CylinderParameters *pars, double q);
48           
49            /// 2D scattering function
50            double cylinder_analytical_2D(CylinderParameters *pars, double q, double phi);
51           
52        The .c file implements those functions.
53       
54        @author: Mathieu Doucet / UTK
55        @contact: mathieu.doucet@nist.gov
56    """
57   
58    def __init__(self, filename):
59        """ Initialization """
60       
61        ## Name of .h file to generate wrapper from
62        self.file = filename
63       
64        # Info read from file
65       
66        ## Name of python class to write
67        self.pythonClass = None
68        ## Parser in struct section
69        self.inStruct = False
70        ## Name of struct for the c object
71        self.structName = None
72        ## Dictionary of parameters
73        self.params = {}
74        ## ModelCalculation module flag
75        self.modelCalcFlag = False
76        ## List of default parameters (text)
77        self.default_list = ""
78        ## Dictionary of units
79        self.details = ""
80        ## List of dispersed parameters
81        self.disp_params = []
82        #model description
83        self.description=''
84        # paramaters for fittable
85        self.fixed= []
86        ## parameters with orientation
87        self.orientation_params =[]
88       
89       
90    def __repr__(self):
91        """ Simple output for printing """
92       
93        rep  = "\n Python class: %s\n\n" % self.pythonClass
94        rep += "  struc name: %s\n\n" % self.structName
95        rep += "  params:     %s\n\n" % self.params
96        rep += "  description:    %s\n\n" % self.description
97        rep += "  Fittable parameters:     %s\n\n"% self.fixed
98        rep += "  Orientation parameters:  %s\n\n"% self.orientation_params
99        return rep
100       
101    def read(self):
102        """ Reads in the .h file to catch parameters of the wrapper """
103       
104        # Check if the file is there
105        if not os.path.isfile(self.file):
106            raise ValueError, "File %s is not a regular file" % self.file
107       
108        # Read file
109        f = open(self.file,'r')
110        buf = f.read()
111       
112        self.default_list = "List of default parameters:\n"
113        #lines = string.split(buf,'\n')
114        lines = buf.split('\n')
115        self.details  = "## Parameter details [units, min, max]\n"
116        self.details += "        self.details = {}\n"
117       
118        #open item in this case Fixed
119        text='text'
120        key2="<%s>"%text.lower()
121        # close an item in this case fixed
122        text='TexT'
123        key3="</%s>"%text.lower()
124       
125        ## Catch fixed parameters
126        key = "[FIXED]"
127        try:
128            self.fixed= lineparser.readhelper(lines,key, key2,key3, file= self.file)
129        except:
130           raise   
131        ## Catch parameters with orientation
132        key = "[ORIENTATION_PARAMS]"   
133        try:
134            self.orientation_params = lineparser.readhelper( lines,key, 
135                                                    key2,key3, file= self.file)
136        except:
137           raise 
138        ## Catch Description
139        key = "[DESCRIPTION]"
140       
141        find_description = False
142        temp=""
143        for line in lines:
144            if line.count(key)>0 :
145               
146                try:
147                    find_description= True
148                    index = line.index(key)
149                    toks = line[index:].split("=",1 )
150                    temp=toks[1].lstrip().rstrip()
151                    text='text'
152                    key2="<%s>"%text.lower()
153                    if re.match(key2,temp)!=None:
154   
155                        toks2=temp.split(key2,1)
156                        self.description=toks2[1]
157                        text='text'
158                        key2="</%s>"%text.lower()
159                        if re.search(key2,toks2[1])!=None:
160                            temp=toks2[1].split(key2,1)
161                            self.description=temp[0]
162                            break
163                     
164                    else:
165                        self.description=temp
166                        break
167                except:
168                     raise ValueError, "Could not parse file %s" % self.file
169            elif find_description:
170                text='text'
171                key2="</%s>"%text.lower()
172                if re.search(key2,line)!=None:
173                    tok=line.split(key2,1)
174                    temp=tok[0].split("//",1)
175                    self.description+=tok[1].lstrip().rstrip()
176                    break
177                else:
178                    if re.search("//",line)!=None:
179                        temp=line.split("//",1)
180                        self.description+='\n\t\t'+temp[1].lstrip().rstrip()
181                       
182                    else:
183                        self.description+='\n\t\t'+line.lstrip().rstrip()
184                   
185               
186               
187        for line in lines:
188           
189            # Catch class name
190            key = "[PYTHONCLASS]"
191            if line.count(key)>0:
192                try:
193                    index = line.index(key)
194                    #toks = string.split( line[index:], "=" )
195                    toks = line[index:].split("=" )
196                    self.pythonClass = toks[1].lstrip().rstrip()
197                except:
198                    raise ValueError, "Could not parse file %s" % self.file
199               
200            # Catch struct name
201            if line.count("typedef struct")>0:
202                # We are entering a struct block
203                self.inStruct = True
204           
205            if self.inStruct and line.count("}")>0:
206                # We are exiting a struct block
207                self.inStruct = False
208   
209                # Catch the name of the struct
210                index = line.index("}")
211                #toks = string.split(line[index+1:],";")
212                toks = line[index+1:].split(";")
213                # Catch pointer definition
214                #toks2 = string.split(toks[0],',')
215                toks2 = toks[0].split(',')
216                self.structName = toks2[0].lstrip().rstrip()
217           
218               
219           
220            # Catch struct content
221            key = "[DEFAULT]"
222            if self.inStruct and line.count(key)>0:
223                # Found a new parameter
224                try:
225                    index = line.index(key)
226                    toks = line[index:].split("=")
227                    toks2 = toks[2].split()
228                    val = float(toks2[0])
229                    self.params[toks[1]] = val
230                    #self.pythonClass = toks[1].lstrip().rstrip()
231                    units = ""
232                    if len(toks2) >= 2:
233                        units = toks2[1]
234                    self.default_list += "         %-15s = %s %s\n" % \
235                        (toks[1], val, units)
236                   
237                    # Check for min and max
238                    min = "None"
239                    max = "None"
240                    if len(toks2) == 4:
241                        min = toks2[2]
242                        max = toks2[3]
243                   
244                    self.details += "        self.details['%s'] = ['%s', %s, %s]\n" % \
245                        (toks[1].lstrip().rstrip(), units.lstrip().rstrip(), min, max)
246                except:
247                    raise ValueError, "Could not parse input file %s \n  %s" % \
248                        (self.file, sys.exc_value)
249               
250               
251            # Catch need for numerical calculations
252            key = "CalcParameters calcPars"
253            if line.count(key)>0:
254                self.modelCalcFlag = True
255               
256            # Catch list of dispersed parameters
257            key = "[DISP_PARAMS]"
258            if line.count(key)>0:
259                try:
260                    index = line.index(key)
261                    toks = line[index:].split("=")
262                    list_str = toks[1].lstrip().rstrip()
263                    self.disp_params = list_str.split(',')
264                except:
265                    raise ValueError, "Could not parse file %s" % self.file
266               
267       
268               
269    def write_c_wrapper(self):
270        """ Writes the C file to create the python extension class
271            The file is written in C[PYTHONCLASS].c
272        """
273       
274        file = open("C"+self.pythonClass+'.cpp', 'w')
275        template = open("classTemplate.txt", 'r')
276       
277        tmp_buf = template.read()
278        #tmp_lines = string.split(tmp_buf,'\n')
279        tmp_lines = tmp_buf.split('\n')
280       
281        for tmp_line in tmp_lines:
282           
283            # Catch class name
284            newline = self.replaceToken(tmp_line, 
285                                        "[PYTHONCLASS]", 'C'+self.pythonClass)
286            #Catch model description
287            #newline = self.replaceToken(tmp_line,
288            #                            "[DESCRIPTION]", self.description)
289            # Catch C model name
290            newline = self.replaceToken(newline, 
291                                        "[CMODEL]", self.pythonClass)
292           
293            # Catch class name
294            newline = self.replaceToken(newline, 
295                                        "[MODELSTRUCT]", self.structName)
296           
297            # Dictionary initialization
298            param_str = "// Initialize parameter dictionary\n"           
299            for par in self.params:
300                param_str += "        PyDict_SetItemString(self->params,\"%s\",Py_BuildValue(\"d\",%f));\n" % \
301                    (par, self.params[par])
302
303            param_str += "        // Initialize dispersion / averaging parameter dict\n"
304            param_str += "        DispersionVisitor* visitor = new DispersionVisitor();\n"
305            param_str += "        PyObject * disp_dict;\n"
306            for par in self.disp_params:
307                par = par.strip()
308                param_str += "        disp_dict = PyDict_New();\n"
309                param_str += "        self->model->%s.dispersion->accept_as_source(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
310                param_str += "        PyDict_SetItemString(self->dispersion, \"%s\", disp_dict);\n" % par
311               
312            # Initialize dispersion object dictionnary
313            param_str += "\n"
314           
315               
316            newline = self.replaceToken(newline,
317                                        "[INITDICTIONARY]", param_str)
318           
319            # Read dictionary
320            param_str = "    // Reader parameter dictionary\n"
321            for par in self.params:
322                param_str += "    self->model->%s = PyFloat_AsDouble( PyDict_GetItemString(self->params, \"%s\") );\n" % \
323                    (par, par)
324                   
325            param_str += "    // Read in dispersion parameters\n"
326            param_str += "    PyObject* disp_dict;\n"
327            param_str += "    DispersionVisitor* visitor = new DispersionVisitor();\n"
328            for par in self.disp_params:
329                par = par.strip()
330                param_str += "    disp_dict = PyDict_GetItemString(self->dispersion, \"%s\");\n" % par
331                param_str += "    self->model->%s.dispersion->accept_as_destination(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
332               
333            newline = self.replaceToken(newline, "[READDICTIONARY]", param_str)
334               
335            # Name of .c file
336            #toks = string.split(self.file,'.')
337            basename = os.path.basename(self.file)
338            toks = basename.split('.')
339            newline = self.replaceToken(newline, "[C_FILENAME]", toks[0])
340           
341            # Include file
342            basename = os.path.basename(self.file)
343            newline = self.replaceToken(newline, 
344                                        "[INCLUDE_FILE]", basename)           
345               
346            # Numerical calcs dealloc
347            dealloc_str = "\n"
348            if self.modelCalcFlag:
349                dealloc_str = "    modelcalculations_dealloc(&(self->model_pars.calcPars));\n"
350            newline = self.replaceToken(newline, 
351                                        "[NUMERICAL_DEALLOC]", dealloc_str)     
352               
353            # Numerical calcs init
354            init_str = "\n"
355            if self.modelCalcFlag:
356                init_str = "        modelcalculations_init(&(self->model_pars.calcPars));\n"
357            newline = self.replaceToken(newline, 
358                                        "[NUMERICAL_INIT]", init_str)     
359               
360            # Numerical calcs reset
361            reset_str = "\n"
362            if self.modelCalcFlag:
363                reset_str = "modelcalculations_reset(&(self->model_pars.calcPars));\n"
364            newline = self.replaceToken(newline, 
365                                        "[NUMERICAL_RESET]", reset_str)     
366               
367            # Setting dispsertion weights
368            set_weights = "    // Ugliness necessary to go from python to C\n"
369            set_weights = "    // TODO: refactor this\n"
370            for par in self.disp_params:
371                par = par.strip()
372                set_weights += "    if (!strcmp(par_name, \"%s\")) {\n" % par
373                set_weights += "        self->model->%s.dispersion = dispersion;\n" % par
374                set_weights += "    } else"
375            newline = self.replaceToken(newline, 
376                                        "[SET_DISPERSION]", set_weights)     
377           
378            # Write new line to the wrapper .c file
379            file.write(newline+'\n')
380           
381           
382        file.close()
383       
384    def write_python_wrapper(self):
385        """ Writes the python file to create the python extension class
386            The file is written in ../[PYTHONCLASS].py
387        """
388       
389        file = open("../"+self.pythonClass+'.py', 'w')
390        template = open("modelTemplate.txt", 'r')
391       
392        tmp_buf = template.read()
393        tmp_lines = tmp_buf.split('\n')
394       
395        for tmp_line in tmp_lines:
396           
397            # Catch class name
398            newline = self.replaceToken(tmp_line, 
399                                        "[CPYTHONCLASS]", 'C'+self.pythonClass)
400           
401            # Catch class name
402            newline = self.replaceToken(newline, 
403                                        "[PYTHONCLASS]", self.pythonClass)
404           
405            # Include file
406            newline = self.replaceToken(newline, 
407                                        "[INCLUDE_FILE]", self.file)   
408                   
409            # Include file
410            newline = self.replaceToken(newline, 
411                                        "[DEFAULT_LIST]", self.default_list)
412            # model description
413            newline = self.replaceToken(newline, 
414                                        "[DESCRIPTION]", self.description)
415            # Parameter details
416            newline = self.replaceToken(newline, 
417                                        "[PAR_DETAILS]", self.details)
418           
419            # fixed list  details
420            newline = self.replaceToken(newline, 
421                                        "[FIXED]",str(self.fixed))
422            ## parameters with orientation
423       
424            newline = self.replaceToken(newline, 
425                               "[ORIENTATION_PARAMS]",str(self.orientation_params))
426           
427            # Write new line to the wrapper .c file
428            file.write(newline+'\n')
429               
430        file.close()
431       
432       
433    def replaceToken(self, line, key, value): #pylint: disable-msg=R0201
434        """ Replace a token in the template file
435            @param line: line of text to inspect
436            @param key: token to look for
437            @param value: string value to replace the token with
438            @return: new string value
439        """
440        lenkey = len(key)
441        newline = line
442       
443        while newline.count(key)>0:
444            index = newline.index(key)
445            newline = newline[:index]+value+newline[index+lenkey:]
446       
447        return newline
448       
449       
450# main
451if __name__ == '__main__':
452    if len(sys.argv)>1:
453        print "Will look for file %s" % sys.argv[1]
454    #app = WrapperGenerator('../c_extensions/elliptical_cylinder.h')
455        app = WrapperGenerator(sys.argv[1])
456    else:
457        app = WrapperGenerator("test.h")
458    app.read()
459    app.write_c_wrapper()
460    app.write_python_wrapper()
461    print app
462   
463# End of file       
Note: See TracBrowser for help on using the repository browser.