source: sasview/sansmodels/src/python_wrapper/WrapperGenerator.py @ dd60b45

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since dd60b45 was 2d1b700, checked in by Mathieu Doucet <doucetm@…>, 13 years ago

refactor refl model and auto-generate c++ wrapper at compile time.

  • Property mode set to 100644
File size: 19.9 KB
RevLine 
[af03ddd]1#!/usr/bin/env python
2""" WrapperGenerator class to generate model code automatically.
3"""
4
[95986b5]5import os, sys,re
[25a608f5]6import lineparser
[af03ddd]7
8class WrapperGenerator:
9    """ Python wrapper generator for C models
10   
11        The developer must provide a header file describing
12        the new model.
13       
14        To provide the name of the Python class to be
15        generated, the .h file must contain the following
16        string in the comments:
17       
18        // [PYTHONCLASS] = my_model
19       
20        where my_model must be replaced by the name of the
21        class that you want to import from sans.models.
22        (example: [PYTHONCLASS] = MyModel
23          will create a class MyModel in sans.models.MyModel.
24          It will also create a class CMyModel in
25          sans_extension.c_models.)
26         
27        Also in comments, each parameter of the params
28        dictionary must be declared with a default value
29        in the following way:
30       
31        // [DEFAULT]=param_name=default_value
32       
33        (example:
34            //  [DEFAULT]=radius=20.0
35        )
36         
37        See cylinder.h for an example.
38       
39       
40        A .c file corresponding to the .h file should also
41        be provided (example: my_model.h, my_model.c).
42   
43        The .h file should define two function definitions. For example,
44        cylinder.h defines the following:
45       
46            /// 1D scattering function
47            double cylinder_analytical_1D(CylinderParameters *pars, double q);
48           
49            /// 2D scattering function
50            double cylinder_analytical_2D(CylinderParameters *pars, double q, double phi);
51           
52        The .c file implements those functions.
53       
54        @author: Mathieu Doucet / UTK
55        @contact: mathieu.doucet@nist.gov
56    """
57   
[2d1b700]58    def __init__(self, filename, output_dir='.'):
[af03ddd]59        """ Initialization """
60       
61        ## Name of .h file to generate wrapper from
62        self.file = filename
63       
64        # Info read from file
65       
66        ## Name of python class to write
67        self.pythonClass = None
68        ## Parser in struct section
69        self.inStruct = False
[a1d1b90]70        self.foundCPP = False
[d62f422]71        self.inParDefs = False
[af03ddd]72        ## Name of struct for the c object
73        self.structName = None
74        ## Dictionary of parameters
75        self.params = {}
76        ## ModelCalculation module flag
77        self.modelCalcFlag = False
78        ## List of default parameters (text)
79        self.default_list = ""
80        ## Dictionary of units
81        self.details = ""
82        ## List of dispersed parameters
83        self.disp_params = []
[4e2f6ef8]84        #model description
85        self.description=''
[836fe6e]86        # paramaters for fittable
87        self.fixed= []
[35aface]88        # paramaters for non-fittable
89        self.non_fittable= []
[25a608f5]90        ## parameters with orientation
91        self.orientation_params =[]
[2d1b700]92        ## output directory for wrappers
93        self.output_dir = output_dir
[25a608f5]94       
[af03ddd]95       
96    def __repr__(self):
97        """ Simple output for printing """
98       
[25a608f5]99        rep  = "\n Python class: %s\n\n" % self.pythonClass
100        rep += "  struc name: %s\n\n" % self.structName
101        rep += "  params:     %s\n\n" % self.params
102        rep += "  description:    %s\n\n" % self.description
103        rep += "  Fittable parameters:     %s\n\n"% self.fixed
[35aface]104        rep += "  Non-Fittable parameters:     %s\n\n"% self.non_fittable
[25a608f5]105        rep += "  Orientation parameters:  %s\n\n"% self.orientation_params
[af03ddd]106        return rep
107       
108    def read(self):
109        """ Reads in the .h file to catch parameters of the wrapper """
110       
111        # Check if the file is there
112        if not os.path.isfile(self.file):
113            raise ValueError, "File %s is not a regular file" % self.file
114       
115        # Read file
116        f = open(self.file,'r')
117        buf = f.read()
118       
119        self.default_list = "List of default parameters:\n"
120        #lines = string.split(buf,'\n')
121        lines = buf.split('\n')
122        self.details  = "## Parameter details [units, min, max]\n"
123        self.details += "        self.details = {}\n"
[25a608f5]124       
[836fe6e]125        #open item in this case Fixed
126        text='text'
127        key2="<%s>"%text.lower()
128        # close an item in this case fixed
[da3dae3]129        text='TexT'
[836fe6e]130        key3="</%s>"%text.lower()
131       
[25a608f5]132        ## Catch fixed parameters
133        key = "[FIXED]"
134        try:
135            self.fixed= lineparser.readhelper(lines,key, key2,key3, file= self.file)
136        except:
137           raise   
[35aface]138        ## Catch non-fittable parameters parameters
139        key = "[NON_FITTABLE_PARAMS]"
140        try:
141            self.non_fittable= lineparser.readhelper(lines,key, key2,key3, file= self.file)
142        except:
143           raise   
144
[25a608f5]145        ## Catch parameters with orientation
146        key = "[ORIENTATION_PARAMS]"   
147        try:
148            self.orientation_params = lineparser.readhelper( lines,key, 
149                                                    key2,key3, file= self.file)
150        except:
151           raise 
152        ## Catch Description
[96672c0]153        key = "[DESCRIPTION]"
[25a608f5]154       
155        find_description = False
[96672c0]156        temp=""
157        for line in lines:
158            if line.count(key)>0 :
[9316609]159               
[96672c0]160                try:
[25a608f5]161                    find_description= True
[96672c0]162                    index = line.index(key)
163                    toks = line[index:].split("=",1 )
164                    temp=toks[1].lstrip().rstrip()
165                    text='text'
[9316609]166                    key2="<%s>"%text.lower()
[96672c0]167                    if re.match(key2,temp)!=None:
[25a608f5]168   
[9316609]169                        toks2=temp.split(key2,1)
170                        self.description=toks2[1]
171                        text='text'
172                        key2="</%s>"%text.lower()
173                        if re.search(key2,toks2[1])!=None:
174                            temp=toks2[1].split(key2,1)
175                            self.description=temp[0]
176                            break
[25a608f5]177                     
[96672c0]178                    else:
179                        self.description=temp
[9316609]180                        break
[96672c0]181                except:
182                     raise ValueError, "Could not parse file %s" % self.file
[25a608f5]183            elif find_description:
[9316609]184                text='text'
185                key2="</%s>"%text.lower()
186                if re.search(key2,line)!=None:
187                    tok=line.split(key2,1)
188                    temp=tok[0].split("//",1)
189                    self.description+=tok[1].lstrip().rstrip()
190                    break
191                else:
192                    if re.search("//",line)!=None:
193                        temp=line.split("//",1)
194                        self.description+='\n\t\t'+temp[1].lstrip().rstrip()
195                       
196                    else:
197                        self.description+='\n\t\t'+line.lstrip().rstrip()
198                   
199               
[96672c0]200               
[af03ddd]201        for line in lines:
202           
203            # Catch class name
204            key = "[PYTHONCLASS]"
205            if line.count(key)>0:
206                try:
207                    index = line.index(key)
208                    #toks = string.split( line[index:], "=" )
209                    toks = line[index:].split("=" )
210                    self.pythonClass = toks[1].lstrip().rstrip()
211                except:
212                    raise ValueError, "Could not parse file %s" % self.file
213               
214            # Catch struct name
[a1d1b90]215            # C++ class definition
216            if line.count("class")>0:
217                # We are entering a class definition
[d62f422]218                self.inParDefs = True
[a1d1b90]219                self.foundCPP = True
220               
221            # Old-Style C struct definition
[af03ddd]222            if line.count("typedef struct")>0:
223                # We are entering a struct block
[d62f422]224                self.inParDefs = True
[af03ddd]225                self.inStruct = True
226           
[d62f422]227            if self.inParDefs and line.count("}")>0:
[af03ddd]228                # We are exiting a struct block
[d62f422]229                self.inParDefs = False
[af03ddd]230               
[d62f422]231                if self.inStruct:
232                    self.inStruct = False
233                    # Catch the name of the struct
234                    index = line.index("}")
235                    toks = line[index+1:].split(";")
236                    # Catch pointer definition
237                    toks2 = toks[0].split(',')
238                    self.structName = toks2[0].lstrip().rstrip()
[96672c0]239           
[af03ddd]240            # Catch struct content
241            key = "[DEFAULT]"
[d62f422]242            if self.inParDefs and line.count(key)>0:
[af03ddd]243                # Found a new parameter
244                try:
245                    index = line.index(key)
246                    toks = line[index:].split("=")
247                    toks2 = toks[2].split()
248                    val = float(toks2[0])
249                    self.params[toks[1]] = val
250                    #self.pythonClass = toks[1].lstrip().rstrip()
251                    units = ""
252                    if len(toks2) >= 2:
253                        units = toks2[1]
254                    self.default_list += "         %-15s = %s %s\n" % \
255                        (toks[1], val, units)
256                   
257                    # Check for min and max
258                    min = "None"
259                    max = "None"
260                    if len(toks2) == 4:
261                        min = toks2[2]
262                        max = toks2[3]
263                   
264                    self.details += "        self.details['%s'] = ['%s', %s, %s]\n" % \
265                        (toks[1].lstrip().rstrip(), units.lstrip().rstrip(), min, max)
266                except:
267                    raise ValueError, "Could not parse input file %s \n  %s" % \
268                        (self.file, sys.exc_value)
269               
270               
271            # Catch need for numerical calculations
272            key = "CalcParameters calcPars"
273            if line.count(key)>0:
274                self.modelCalcFlag = True
275               
276            # Catch list of dispersed parameters
277            key = "[DISP_PARAMS]"
278            if line.count(key)>0:
279                try:
280                    index = line.index(key)
281                    toks = line[index:].split("=")
282                    list_str = toks[1].lstrip().rstrip()
283                    self.disp_params = list_str.split(',')
284                except:
285                    raise ValueError, "Could not parse file %s" % self.file
286               
[95986b5]287       
[af03ddd]288               
289    def write_c_wrapper(self):
290        """ Writes the C file to create the python extension class
291            The file is written in C[PYTHONCLASS].c
292        """
[2d1b700]293        file_path = os.path.join(self.output_dir, "C"+self.pythonClass+'.cpp')
294        file = open(file_path, 'w')
[af03ddd]295       
[2d1b700]296        template = open(os.path.join(os.path.dirname(__file__), "classTemplate.txt"), 'r')
[af03ddd]297       
298        tmp_buf = template.read()
299        #tmp_lines = string.split(tmp_buf,'\n')
300        tmp_lines = tmp_buf.split('\n')
301       
302        for tmp_line in tmp_lines:
303           
304            # Catch class name
305            newline = self.replaceToken(tmp_line, 
306                                        "[PYTHONCLASS]", 'C'+self.pythonClass)
[9316609]307            #Catch model description
[95986b5]308            #newline = self.replaceToken(tmp_line,
309            #                            "[DESCRIPTION]", self.description)
[af03ddd]310            # Catch C model name
311            newline = self.replaceToken(newline, 
312                                        "[CMODEL]", self.pythonClass)
313           
314            # Catch class name
315            newline = self.replaceToken(newline, 
316                                        "[MODELSTRUCT]", self.structName)
317           
318            # Dictionary initialization
319            param_str = "// Initialize parameter dictionary\n"           
320            for par in self.params:
[35aface]321                param_str += "        PyDict_SetItemString(self->params,\"%s\",Py_BuildValue(\"d\",%10.12f));\n" % \
[0f5bc9f]322                    (par, self.params[par])
[af03ddd]323
324            param_str += "        // Initialize dispersion / averaging parameter dict\n"
325            param_str += "        DispersionVisitor* visitor = new DispersionVisitor();\n"
326            param_str += "        PyObject * disp_dict;\n"
327            for par in self.disp_params:
328                par = par.strip()
329                param_str += "        disp_dict = PyDict_New();\n"
330                param_str += "        self->model->%s.dispersion->accept_as_source(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
331                param_str += "        PyDict_SetItemString(self->dispersion, \"%s\", disp_dict);\n" % par
332               
333            # Initialize dispersion object dictionnary
334            param_str += "\n"
335           
336               
337            newline = self.replaceToken(newline,
338                                        "[INITDICTIONARY]", param_str)
339           
340            # Read dictionary
341            param_str = "    // Reader parameter dictionary\n"
342            for par in self.params:
343                param_str += "    self->model->%s = PyFloat_AsDouble( PyDict_GetItemString(self->params, \"%s\") );\n" % \
344                    (par, par)
345                   
346            param_str += "    // Read in dispersion parameters\n"
347            param_str += "    PyObject* disp_dict;\n"
348            param_str += "    DispersionVisitor* visitor = new DispersionVisitor();\n"
349            for par in self.disp_params:
350                par = par.strip()
351                param_str += "    disp_dict = PyDict_GetItemString(self->dispersion, \"%s\");\n" % par
352                param_str += "    self->model->%s.dispersion->accept_as_destination(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
353               
354            newline = self.replaceToken(newline, "[READDICTIONARY]", param_str)
355               
356            # Name of .c file
357            #toks = string.split(self.file,'.')
358            basename = os.path.basename(self.file)
359            toks = basename.split('.')
360            newline = self.replaceToken(newline, "[C_FILENAME]", toks[0])
361           
362            # Include file
363            basename = os.path.basename(self.file)
364            newline = self.replaceToken(newline, 
[a1d1b90]365                                        "[INCLUDE_FILE]", self.file) 
366            if self.foundCPP:
367                newline = self.replaceToken(newline, 
368                                            "[C_INCLUDE_FILE]", "") 
369                newline = self.replaceToken(newline, 
370                                            "[CPP_INCLUDE_FILE]", "#include \"%s\"" % basename) 
371            else: 
372                newline = self.replaceToken(newline, 
373                                            "[C_INCLUDE_FILE]", "#include \"%s\"" % basename)   
374                newline = self.replaceToken(newline, 
375                                            "[CPP_INCLUDE_FILE]", "#include \"models.hh\"") 
[af03ddd]376               
377            # Numerical calcs dealloc
378            dealloc_str = "\n"
379            if self.modelCalcFlag:
380                dealloc_str = "    modelcalculations_dealloc(&(self->model_pars.calcPars));\n"
381            newline = self.replaceToken(newline, 
382                                        "[NUMERICAL_DEALLOC]", dealloc_str)     
383               
384            # Numerical calcs init
385            init_str = "\n"
386            if self.modelCalcFlag:
387                init_str = "        modelcalculations_init(&(self->model_pars.calcPars));\n"
388            newline = self.replaceToken(newline, 
389                                        "[NUMERICAL_INIT]", init_str)     
390               
391            # Numerical calcs reset
392            reset_str = "\n"
393            if self.modelCalcFlag:
394                reset_str = "modelcalculations_reset(&(self->model_pars.calcPars));\n"
395            newline = self.replaceToken(newline, 
396                                        "[NUMERICAL_RESET]", reset_str)     
397               
398            # Setting dispsertion weights
399            set_weights = "    // Ugliness necessary to go from python to C\n"
400            set_weights = "    // TODO: refactor this\n"
401            for par in self.disp_params:
402                par = par.strip()
403                set_weights += "    if (!strcmp(par_name, \"%s\")) {\n" % par
404                set_weights += "        self->model->%s.dispersion = dispersion;\n" % par
405                set_weights += "    } else"
406            newline = self.replaceToken(newline, 
407                                        "[SET_DISPERSION]", set_weights)     
408           
409            # Write new line to the wrapper .c file
410            file.write(newline+'\n')
411           
412           
413        file.close()
414       
415    def write_python_wrapper(self):
416        """ Writes the python file to create the python extension class
417            The file is written in ../[PYTHONCLASS].py
418        """
[2d1b700]419        file_path = os.path.join(self.output_dir, "../sans/models/"+self.pythonClass+'.py')
420        file = open(file_path, 'w')
421        template = open(os.path.join(os.path.dirname(__file__), "modelTemplate.txt"), 'r')
[af03ddd]422       
423        tmp_buf = template.read()
424        tmp_lines = tmp_buf.split('\n')
425       
426        for tmp_line in tmp_lines:
427           
428            # Catch class name
429            newline = self.replaceToken(tmp_line, 
430                                        "[CPYTHONCLASS]", 'C'+self.pythonClass)
431           
432            # Catch class name
433            newline = self.replaceToken(newline, 
434                                        "[PYTHONCLASS]", self.pythonClass)
435           
436            # Include file
437            newline = self.replaceToken(newline, 
438                                        "[INCLUDE_FILE]", self.file)   
439                   
440            # Include file
441            newline = self.replaceToken(newline, 
442                                        "[DEFAULT_LIST]", self.default_list)
[95986b5]443            # model description
[af03ddd]444            newline = self.replaceToken(newline, 
[95986b5]445                                        "[DESCRIPTION]", self.description)
[4e2f6ef8]446            # Parameter details
447            newline = self.replaceToken(newline, 
[95986b5]448                                        "[PAR_DETAILS]", self.details)
449           
[836fe6e]450            # fixed list  details
451            newline = self.replaceToken(newline, 
452                                        "[FIXED]",str(self.fixed))
[35aface]453            # non-fittable list  details
454            newline = self.replaceToken(newline, 
455                                        "[NON_FITTABLE_PARAMS]",str(self.non_fittable))
[25a608f5]456            ## parameters with orientation
457       
458            newline = self.replaceToken(newline, 
459                               "[ORIENTATION_PARAMS]",str(self.orientation_params))
[836fe6e]460           
[af03ddd]461            # Write new line to the wrapper .c file
462            file.write(newline+'\n')
463               
464        file.close()
465       
466       
467    def replaceToken(self, line, key, value): #pylint: disable-msg=R0201
468        """ Replace a token in the template file
469            @param line: line of text to inspect
470            @param key: token to look for
471            @param value: string value to replace the token with
472            @return: new string value
473        """
474        lenkey = len(key)
475        newline = line
[836fe6e]476       
[af03ddd]477        while newline.count(key)>0:
478            index = newline.index(key)
479            newline = newline[:index]+value+newline[index+lenkey:]
[836fe6e]480       
[af03ddd]481        return newline
482       
483       
484# main
485if __name__ == '__main__':
486    if len(sys.argv)>1:
487        print "Will look for file %s" % sys.argv[1]
[95986b5]488    #app = WrapperGenerator('../c_extensions/elliptical_cylinder.h')
[af03ddd]489        app = WrapperGenerator(sys.argv[1])
490    else:
491        app = WrapperGenerator("test.h")
492    app.read()
493    app.write_c_wrapper()
494    app.write_python_wrapper()
495    print app
496   
497# End of file       
Note: See TracBrowser for help on using the repository browser.