source: sasview/sansmodels/src/python_wrapper/WrapperGenerator.py @ fa6db8b

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since fa6db8b was fa6db8b, checked in by Kieran Campbell <kieranrcampbell@…>, 12 years ago

Added mechanism for model categorization and DAB Model python → c++

  • Property mode set to 100644
File size: 21.3 KB
Line 
1#!/usr/bin/env python
2""" WrapperGenerator class to generate model code automatically.
3"""
4
5import os, sys,re
6import lineparser
7
8class WrapperGenerator:
9    """ Python wrapper generator for C models
10   
11        The developer must provide a header file describing
12        the new model.
13       
14        To provide the name of the Python class to be
15        generated, the .h file must contain the following
16        string in the comments:
17       
18        // [PYTHONCLASS] = my_model
19       
20        where my_model must be replaced by the name of the
21        class that you want to import from sans.models.
22        (example: [PYTHONCLASS] = MyModel
23          will create a class MyModel in sans.models.MyModel.
24          It will also create a class CMyModel in
25          sans_extension.c_models.)
26         
27        Also in comments, each parameter of the params
28        dictionary must be declared with a default value
29        in the following way:
30       
31        // [DEFAULT]=param_name=default_value
32       
33        (example:
34            //  [DEFAULT]=radius=20.0
35        )
36         
37        See cylinder.h for an example.
38       
39       
40        A .c file corresponding to the .h file should also
41        be provided (example: my_model.h, my_model.c).
42   
43        The .h file should define two function definitions. For example,
44        cylinder.h defines the following:
45       
46            /// 1D scattering function
47            double cylinder_analytical_1D(CylinderParameters *pars, double q);
48           
49            /// 2D scattering function
50            double cylinder_analytical_2D(CylinderParameters *pars, double q, double phi);
51           
52        The .c file implements those functions.
53       
54        @author: Mathieu Doucet / UTK
55        @contact: mathieu.doucet@nist.gov
56    """
57   
58    def __init__(self, filename, output_dir='.', c_wrapper_dir='.'):
59        """ Initialization """
60       
61        ## Name of .h file to generate wrapper from
62        self.file = filename
63       
64        # Info read from file
65       
66        ## Name of python class to write
67        self.pythonClass = None
68        ## Parser in struct section
69        self.inStruct = False
70        self.foundCPP = False
71        self.inParDefs = False
72        ## Name of struct for the c object
73        self.structName = None
74        ## Dictionary of parameters
75        self.params = {}
76        ## ModelCalculation module flag
77        self.modelCalcFlag = False
78        ## List of default parameters (text)
79        self.default_list = ""
80        ## Dictionary of units
81        self.details = ""
82        ## List of dispersed parameters
83        self.disp_params = []
84        #model description
85        self.description=''
86        # paramaters for fittable
87        self.fixed= []
88        # paramaters for non-fittable
89        self.non_fittable= []
90        ## parameters with orientation
91        self.orientation_params =[]
92        # Model category
93        self.category = None
94        ## output directory for wrappers
95        self.output_dir = output_dir
96        self.c_wrapper_dir = c_wrapper_dir
97
98       
99       
100    def __repr__(self):
101        """ Simple output for printing """
102       
103        rep  = "\n Python class: %s\n\n" % self.pythonClass
104        rep += "  struc name: %s\n\n" % self.structName
105        rep += "  params:     %s\n\n" % self.params
106        rep += "  description:    %s\n\n" % self.description
107        rep += "  Fittable parameters:     %s\n\n"% self.fixed
108        rep += "  Non-Fittable parameters:     %s\n\n"% self.non_fittable
109        rep += "  Orientation parameters:  %s\n\n"% self.orientation_params
110        return rep
111       
112    def read(self):
113        """ Reads in the .h file to catch parameters of the wrapper """
114       
115        # Check if the file is there
116        if not os.path.isfile(self.file):
117            raise ValueError, "File %s is not a regular file" % self.file
118       
119        # Read file
120        f = open(self.file,'r')
121        buf = f.read()
122       
123        self.default_list = "List of default parameters:\n"
124        #lines = string.split(buf,'\n')
125        lines = buf.split('\n')
126        self.details  = "## Parameter details [units, min, max]\n"
127        self.details += "        self.details = {}\n"
128       
129        #open item in this case Fixed
130        text='text'
131        key2="<%s>"%text.lower()
132        # close an item in this case fixed
133        text='TexT'
134        key3="</%s>"%text.lower()
135       
136        ## Catch fixed parameters
137        key = "[FIXED]"
138        try:
139            self.fixed= lineparser.readhelper(lines, key, 
140                                              key2, key3, file=self.file)
141        except:
142           raise   
143        ## Catch non-fittable parameters parameters
144        key = "[NON_FITTABLE_PARAMS]"
145        try:
146            self.non_fittable= lineparser.readhelper(lines, key, key2,
147                                                     key3, file=self.file)
148        except:
149           raise   
150
151        ## Catch parameters with orientation
152        key = "[ORIENTATION_PARAMS]"   
153        try:
154            self.orientation_params = lineparser.readhelper(lines, key, 
155                                                    key2, key3, file=self.file)
156        except:
157           raise 
158        ## Catch Description
159        key = "[DESCRIPTION]"
160       
161        find_description = False
162        temp=""
163        for line in lines:
164            if line.count(key)>0 :
165               
166                try:
167                    find_description= True
168                    index = line.index(key)
169                    toks = line[index:].split("=",1 )
170                    temp=toks[1].lstrip().rstrip()
171                    text='text'
172                    key2="<%s>"%text.lower()
173                    if re.match(key2,temp)!=None:
174   
175                        toks2=temp.split(key2,1)
176                        self.description=toks2[1]
177                        text='text'
178                        key2="</%s>"%text.lower()
179                        if re.search(key2,toks2[1])!=None:
180                            temp=toks2[1].split(key2,1)
181                            self.description=temp[0]
182                            break
183                     
184                    else:
185                        self.description=temp
186                        break
187                except:
188                     raise ValueError, "Could not parse file %s" % self.file
189            elif find_description:
190                text='text'
191                key2="</%s>"%text.lower()
192                if re.search(key2,line)!=None:
193                    tok=line.split(key2,1)
194                    temp=tok[0].split("//",1)
195                    self.description+=tok[1].lstrip().rstrip()
196                    break
197                else:
198                    if re.search("//",line)!=None:
199                        temp=line.split("//",1)
200                        self.description+='\n\t\t'+temp[1].lstrip().rstrip()
201                       
202                    else:
203                        self.description+='\n\t\t'+line.lstrip().rstrip()
204                   
205               
206     
207        for line in lines:
208           
209            # Catch class name
210            key = "[PYTHONCLASS]"
211            if line.count(key)>0:
212                try:
213                    index = line.index(key)
214                    toks = line[index:].split("=" )
215                    self.pythonClass = toks[1].lstrip().rstrip()
216
217                except:
218                    raise ValueError, "Could not parse file %s" % self.file
219               
220            key = "[CATEGORY]"
221            if line.count(key)>0:
222                try:
223                    index = line.index(key)
224                    toks = line[index:].split("=")
225                    self.category = toks[1].lstrip().rstrip()
226
227                except:
228                    raise ValueError, "Could not parse file %s" % self.file
229
230
231
232            # Catch struct name
233            # C++ class definition
234            if line.count("class")>0:
235                # We are entering a class definition
236                self.inParDefs = True
237                self.foundCPP = True
238               
239            # Old-Style C struct definition
240            if line.count("typedef struct")>0:
241                # We are entering a struct block
242                self.inParDefs = True
243                self.inStruct = True
244           
245            if self.inParDefs and line.count("}")>0:
246                # We are exiting a struct block
247                self.inParDefs = False
248               
249                if self.inStruct:
250                    self.inStruct = False
251                    # Catch the name of the struct
252                    index = line.index("}")
253                    toks = line[index+1:].split(";")
254                    # Catch pointer definition
255                    toks2 = toks[0].split(',')
256                    self.structName = toks2[0].lstrip().rstrip()
257           
258            # Catch struct content
259            key = "[DEFAULT]"
260            if self.inParDefs and line.count(key)>0:
261                # Found a new parameter
262                try:
263                    index = line.index(key)
264                    toks = line[index:].split("=")
265                    toks2 = toks[2].split()
266                    val = float(toks2[0])
267                    self.params[toks[1]] = val
268                    #self.pythonClass = toks[1].lstrip().rstrip()
269                    units = ""
270                    if len(toks2) >= 2:
271                        units = toks2[1]
272                    self.default_list += "         %-15s = %s %s\n" % \
273                        (toks[1], val, units)
274                   
275                    # Check for min and max
276                    min = "None"
277                    max = "None"
278                    if len(toks2) == 4:
279                        min = toks2[2]
280                        max = toks2[3]
281                   
282                    self.details += "        self.details['%s'] = ['%s', %s, %s]\n" % \
283                        (toks[1].lstrip().rstrip(), units.lstrip().rstrip(), min, max)
284                except:
285                    raise ValueError, "Could not parse input file %s \n  %s" % \
286                        (self.file, sys.exc_value)
287               
288               
289            # Catch need for numerical calculations
290            key = "CalcParameters calcPars"
291            if line.count(key)>0:
292                self.modelCalcFlag = True
293               
294            # Catch list of dispersed parameters
295            key = "[DISP_PARAMS]"
296            if line.count(key)>0:
297                try:
298                    index = line.index(key)
299                    toks = line[index:].split("=")
300                    list_str = toks[1].lstrip().rstrip()
301                    self.disp_params = list_str.split(',')
302                except:
303                    raise ValueError, "Could not parse file %s" % self.file
304
305    def write_c_wrapper(self):
306        """ Writes the C file to create the python extension class
307            The file is written in C[PYTHONCLASS].c
308        """
309        file_path = os.path.join(self.c_wrapper_dir, "C"+self.pythonClass+'.cpp')
310        file = open(file_path, 'w')
311       
312        template = open(os.path.join(os.path.dirname(__file__), "classTemplate.txt"), 'r')
313       
314        tmp_buf = template.read()
315        #tmp_lines = string.split(tmp_buf,'\n')
316        tmp_lines = tmp_buf.split('\n')
317       
318        for tmp_line in tmp_lines:
319           
320            # Catch class name
321            newline = self.replaceToken(tmp_line, 
322                                        "[PYTHONCLASS]", 'C'+self.pythonClass)
323            #Catch model description
324            #newline = self.replaceToken(tmp_line,
325            #                            "[DESCRIPTION]", self.description)
326            # Catch C model name
327            newline = self.replaceToken(newline, 
328                                        "[CMODEL]", self.pythonClass)
329           
330            # Catch class name
331            newline = self.replaceToken(newline, 
332                                        "[MODELSTRUCT]", self.structName)
333           
334            # Dictionary initialization
335            param_str = "// Initialize parameter dictionary\n"           
336            for par in self.params:
337                param_str += "        PyDict_SetItemString(self->params,\"%s\",Py_BuildValue(\"d\",%10.12f));\n" % \
338                    (par, self.params[par])
339
340            if len(self.disp_params)>0:
341                param_str += "        // Initialize dispersion / averaging parameter dict\n"
342                param_str += "        DispersionVisitor* visitor = new DispersionVisitor();\n"
343                param_str += "        PyObject * disp_dict;\n"
344                for par in self.disp_params:
345                    par = par.strip()
346                    param_str += "        disp_dict = PyDict_New();\n"
347                    param_str += "        self->model->%s.dispersion->accept_as_source(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
348                    param_str += "        PyDict_SetItemString(self->dispersion, \"%s\", disp_dict);\n" % par
349               
350            # Initialize dispersion object dictionnary
351            param_str += "\n"
352           
353               
354            newline = self.replaceToken(newline,
355                                        "[INITDICTIONARY]", param_str)
356           
357            # Read dictionary
358            param_str = "    // Reader parameter dictionary\n"
359            for par in self.params:
360                param_str += "    self->model->%s = PyFloat_AsDouble( PyDict_GetItemString(self->params, \"%s\") );\n" % \
361                    (par, par)
362                   
363            if len(self.disp_params)>0:
364                param_str += "    // Read in dispersion parameters\n"
365                param_str += "    PyObject* disp_dict;\n"
366                param_str += "    DispersionVisitor* visitor = new DispersionVisitor();\n"
367                for par in self.disp_params:
368                    par = par.strip()
369                    param_str += "    disp_dict = PyDict_GetItemString(self->dispersion, \"%s\");\n" % par
370                    param_str += "    self->model->%s.dispersion->accept_as_destination(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
371               
372            newline = self.replaceToken(newline, "[READDICTIONARY]", param_str)
373               
374            # Name of .c file
375            #toks = string.split(self.file,'.')
376            basename = os.path.basename(self.file)
377            toks = basename.split('.')
378            newline = self.replaceToken(newline, "[C_FILENAME]", toks[0])
379           
380            # Include file
381            basename = os.path.basename(self.file)
382            newline = self.replaceToken(newline, 
383                                        "[INCLUDE_FILE]", self.file) 
384            if self.foundCPP:
385                newline = self.replaceToken(newline, 
386                                            "[C_INCLUDE_FILE]", "") 
387                newline = self.replaceToken(newline, 
388                                            "[CPP_INCLUDE_FILE]", "#include \"%s\"" % basename) 
389            else: 
390                newline = self.replaceToken(newline, 
391                                            "[C_INCLUDE_FILE]", "#include \"%s\"" % basename)   
392                newline = self.replaceToken(newline, 
393                                            "[CPP_INCLUDE_FILE]", "#include \"models.hh\"") 
394               
395            # Numerical calcs dealloc
396            dealloc_str = "\n"
397            if self.modelCalcFlag:
398                dealloc_str = "    modelcalculations_dealloc(&(self->model_pars.calcPars));\n"
399            newline = self.replaceToken(newline, 
400                                        "[NUMERICAL_DEALLOC]", dealloc_str)     
401               
402            # Numerical calcs init
403            init_str = "\n"
404            if self.modelCalcFlag:
405                init_str = "        modelcalculations_init(&(self->model_pars.calcPars));\n"
406            newline = self.replaceToken(newline, 
407                                        "[NUMERICAL_INIT]", init_str)     
408               
409            # Numerical calcs reset
410            reset_str = "\n"
411            if self.modelCalcFlag:
412                reset_str = "modelcalculations_reset(&(self->model_pars.calcPars));\n"
413            newline = self.replaceToken(newline, 
414                                        "[NUMERICAL_RESET]", reset_str)     
415               
416            # Setting dispsertion weights
417            set_weights = "    // Ugliness necessary to go from python to C\n"
418            set_weights = "    // TODO: refactor this\n"
419            for par in self.disp_params:
420                par = par.strip()
421                set_weights += "    if (!strcmp(par_name, \"%s\")) {\n" % par
422                set_weights += "        self->model->%s.dispersion = dispersion;\n" % par
423                set_weights += "    } else"
424            newline = self.replaceToken(newline, 
425                                        "[SET_DISPERSION]", set_weights)     
426           
427            # Write new line to the wrapper .c file
428            file.write(newline+'\n')
429           
430           
431        file.close()
432       
433    def write_python_wrapper(self):
434        """ Writes the python file to create the python extension class
435            The file is written in ../[PYTHONCLASS].py
436        """
437        file_path = os.path.join(self.output_dir, self.pythonClass+'.py')
438        file = open(file_path, 'w')
439        template = open(os.path.join(os.path.dirname(__file__), "modelTemplate.txt"), 'r')
440       
441        tmp_buf = template.read()
442        tmp_lines = tmp_buf.split('\n')
443       
444        for tmp_line in tmp_lines:
445           
446            # Catch class name
447            newline = self.replaceToken(tmp_line, 
448                                        "[CPYTHONCLASS]", 'C'+self.pythonClass)
449           
450            # Catch class name
451            newline = self.replaceToken(newline, 
452                                        "[PYTHONCLASS]", self.pythonClass)
453           
454            # Include file
455            newline = self.replaceToken(newline, 
456                                        "[INCLUDE_FILE]", self.file)   
457                   
458            # Include file
459            newline = self.replaceToken(newline, 
460                                        "[DEFAULT_LIST]", self.default_list)
461            # model description
462            newline = self.replaceToken(newline, 
463                                        "[DESCRIPTION]", self.description)
464            # Parameter details
465            newline = self.replaceToken(newline, 
466                                        "[PAR_DETAILS]", self.details)
467           
468            # fixed list  details
469            fixed_str = str(self.fixed)
470            fixed_str = fixed_str.replace(', ', ',\n                      ')
471            newline = self.replaceToken(newline, "[FIXED]",fixed_str)
472           
473            # non-fittable list details
474            pars_str = str(self.non_fittable)
475            pars_str = pars_str.replace(', ', 
476                                        ',\n                             ')
477            newline = self.replaceToken(newline, 
478                                        "[NON_FITTABLE_PARAMS]",
479                                        pars_str)
480           
481            ## parameters with orientation
482            oriented_str = str(self.orientation_params)
483            formatted_endl = ',\n                                   '
484            oriented_str = oriented_str.replace(', ', formatted_endl)
485            newline = self.replaceToken(newline, 
486                               "[ORIENTATION_PARAMS]", oriented_str)
487           
488            if self.category:
489                newline = self.replaceToken(newline,"[CATEGORY]"
490                                            ,'"' + self.category + '"')
491            else:
492                newline = self.replaceToken(newline,"[CATEGORY]",
493                                            "None")
494           
495
496
497            # Write new line to the wrapper .c file
498            file.write(newline+'\n')
499               
500        file.close()
501       
502       
503    def replaceToken(self, line, key, value): #pylint: disable-msg=R0201
504        """ Replace a token in the template file
505            @param line: line of text to inspect
506            @param key: token to look for
507            @param value: string value to replace the token with
508            @return: new string value
509        """
510        lenkey = len(key)
511        newline = line
512       
513        while newline.count(key)>0:
514            index = newline.index(key)
515            newline = newline[:index]+value+newline[index+lenkey:]
516       
517        return newline
518       
519    def getModelName(self):
520        return self.pythonClass
521       
522
523
524# main
525if __name__ == '__main__':
526    if len(sys.argv)>1:
527        print "Will look for file %s" % sys.argv[1]
528        app = WrapperGenerator(sys.argv[1])
529    else:
530        app = WrapperGenerator("test.h")
531    app.read()
532    app.write_c_wrapper()
533    app.write_python_wrapper()
534    print app
535   
536# End of file       
Note: See TracBrowser for help on using the repository browser.