source: sasview/sansmodels/src/python_wrapper/WrapperGenerator.py @ 325bc4a

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 325bc4a was 5da3cc5, checked in by Kieran Campbell <kieranrcampbell@…>, 12 years ago

Implemented erf(x) in libfunc.c and added pass down to C++ of multfactor

  • Property mode set to 100644
File size: 22.9 KB
Line 
1#!/usr/bin/env python
2""" WrapperGenerator class to generate model code automatically.
3"""
4
5import os, sys,re
6import lineparser
7
8class WrapperGenerator:
9    """ Python wrapper generator for C models
10   
11        The developer must provide a header file describing
12        the new model.
13       
14        To provide the name of the Python class to be
15        generated, the .h file must contain the following
16        string in the comments:
17       
18        // [PYTHONCLASS] = my_model
19       
20        where my_model must be replaced by the name of the
21        class that you want to import from sans.models.
22        (example: [PYTHONCLASS] = MyModel
23          will create a class MyModel in sans.models.MyModel.
24          It will also create a class CMyModel in
25          sans_extension.c_models.)
26         
27        Also in comments, each parameter of the params
28        dictionary must be declared with a default value
29        in the following way:
30       
31        // [DEFAULT]=param_name=default_value
32       
33        (example:
34            //  [DEFAULT]=radius=20.0
35        )
36         
37        See cylinder.h for an example.
38       
39       
40        A .c file corresponding to the .h file should also
41        be provided (example: my_model.h, my_model.c).
42   
43        The .h file should define two function definitions. For example,
44        cylinder.h defines the following:
45       
46            /// 1D scattering function
47            double cylinder_analytical_1D(CylinderParameters *pars, double q);
48           
49            /// 2D scattering function
50            double cylinder_analytical_2D(CylinderParameters *pars, double q, double phi);
51           
52        The .c file implements those functions.
53       
54        @author: Mathieu Doucet / UTK
55        @contact: mathieu.doucet@nist.gov
56    """
57   
58    def __init__(self, filename, output_dir='.', c_wrapper_dir='.'):
59        """ Initialization """
60       
61        ## Name of .h file to generate wrapper from
62        self.file = filename
63       
64        # Info read from file
65       
66        ## Name of python class to write
67        self.pythonClass = None
68        ## Parser in struct section
69        self.inStruct = False
70        self.foundCPP = False
71        self.inParDefs = False
72        ## Name of struct for the c object
73        self.structName = None
74        ## Dictionary of parameters
75        self.params = {}
76        ## ModelCalculation module flag
77        self.modelCalcFlag = False
78        ## List of default parameters (text)
79        self.default_list = ""
80        ## Dictionary of units
81        self.details = ""
82        ## List of dispersed parameters
83        self.disp_params = []
84        #model description
85        self.description=''
86        # paramaters for fittable
87        self.fixed= []
88        # paramaters for non-fittable
89        self.non_fittable= []
90        ## parameters with orientation
91        self.orientation_params =[]
92        # Model category
93        self.category = None
94        # Whether model belongs to multifunc
95        self.is_multifunc = False
96        ## output directory for wrappers
97        self.output_dir = output_dir
98        self.c_wrapper_dir = c_wrapper_dir
99
100       
101       
102    def __repr__(self):
103        """ Simple output for printing """
104       
105        rep  = "\n Python class: %s\n\n" % self.pythonClass
106        rep += "  struc name: %s\n\n" % self.structName
107        rep += "  params:     %s\n\n" % self.params
108        rep += "  description:    %s\n\n" % self.description
109        rep += "  Fittable parameters:     %s\n\n"% self.fixed
110        rep += "  Non-Fittable parameters:     %s\n\n"% self.non_fittable
111        rep += "  Orientation parameters:  %s\n\n"% self.orientation_params
112        return rep
113       
114    def read(self):
115        """ Reads in the .h file to catch parameters of the wrapper """
116       
117        # Check if the file is there
118        if not os.path.isfile(self.file):
119            raise ValueError, "File %s is not a regular file" % self.file
120       
121        # Read file
122        f = open(self.file,'r')
123        buf = f.read()
124       
125        self.default_list = "List of default parameters:\n"
126        #lines = string.split(buf,'\n')
127        lines = buf.split('\n')
128        self.details  = "## Parameter details [units, min, max]\n"
129        self.details += "        self.details = {}\n"
130       
131        #open item in this case Fixed
132        text='text'
133        key2="<%s>"%text.lower()
134        # close an item in this case fixed
135        text='TexT'
136        key3="</%s>"%text.lower()
137       
138        ## Catch fixed parameters
139        key = "[FIXED]"
140        try:
141            self.fixed= lineparser.readhelper(lines, key, 
142                                              key2, key3, file=self.file)
143        except:
144           raise   
145        ## Catch non-fittable parameters parameters
146        key = "[NON_FITTABLE_PARAMS]"
147        try:
148            self.non_fittable= lineparser.readhelper(lines, key, key2,
149                                                     key3, file=self.file)
150        except:
151           raise   
152
153        ## Catch parameters with orientation
154        key = "[ORIENTATION_PARAMS]"   
155        try:
156            self.orientation_params = lineparser.readhelper(lines, key, 
157                                                    key2, key3, file=self.file)
158        except:
159           raise 
160        ## Catch Description
161        key = "[DESCRIPTION]"
162       
163        find_description = False
164        temp=""
165        for line in lines:
166            if line.count(key)>0 :
167               
168                try:
169                    find_description= True
170                    index = line.index(key)
171                    toks = line[index:].split("=",1 )
172                    temp=toks[1].lstrip().rstrip()
173                    text='text'
174                    key2="<%s>"%text.lower()
175                    if re.match(key2,temp)!=None:
176   
177                        toks2=temp.split(key2,1)
178                        self.description=toks2[1]
179                        text='text'
180                        key2="</%s>"%text.lower()
181                        if re.search(key2,toks2[1])!=None:
182                            temp=toks2[1].split(key2,1)
183                            self.description=temp[0]
184                            break
185                     
186                    else:
187                        self.description=temp
188                        break
189                except:
190                     raise ValueError, "Could not parse file %s" % self.file
191            elif find_description:
192                text='text'
193                key2="</%s>"%text.lower()
194                if re.search(key2,line)!=None:
195                    tok=line.split(key2,1)
196                    temp=tok[0].split("//",1)
197                    self.description+=tok[1].lstrip().rstrip()
198                    break
199                else:
200                    if re.search("//",line)!=None:
201                        temp=line.split("//",1)
202                        self.description+='\n\t\t'+temp[1].lstrip().rstrip()
203                       
204                    else:
205                        self.description+='\n\t\t'+line.lstrip().rstrip()
206                   
207               
208     
209        for line in lines:
210           
211            # Catch class name
212            key = "[PYTHONCLASS]"
213            if line.count(key)>0:
214                try:
215                    index = line.index(key)
216                    toks = line[index:].split("=" )
217                    self.pythonClass = toks[1].lstrip().rstrip()
218
219                except:
220                    raise ValueError, "Could not parse file %s" % self.file
221               
222            key = "[CATEGORY]"
223            if line.count(key)>0:
224                try:
225                    index = line.index(key)
226                    toks = line[index:].split("=")
227                    self.category = toks[1].lstrip().rstrip()
228
229                except:
230                    raise ValueError, "Could not parse file %s" % self.file
231
232            # is_multifunc
233            key = "[MULTIPLICITY_INFO]"
234            if line.count(key) > 0:
235                self.is_multifunc = True
236                try:
237                    index = line.index(key)
238                    toks = line[index:].split("=")
239                    self.multiplicity_info = toks[1].lstrip().rstrip()
240                except:
241                    raise ValueError, "Could not parse file %s" % self.file
242
243            # Catch struct name
244            # C++ class definition
245            if line.count("class")>0:
246                # We are entering a class definition
247                self.inParDefs = True
248                self.foundCPP = True
249               
250            # Old-Style C struct definition
251            if line.count("typedef struct")>0:
252                # We are entering a struct block
253                self.inParDefs = True
254                self.inStruct = True
255           
256            if self.inParDefs and line.count("}")>0:
257                # We are exiting a struct block
258                self.inParDefs = False
259               
260                if self.inStruct:
261                    self.inStruct = False
262                    # Catch the name of the struct
263                    index = line.index("}")
264                    toks = line[index+1:].split(";")
265                    # Catch pointer definition
266                    toks2 = toks[0].split(',')
267                    self.structName = toks2[0].lstrip().rstrip()
268           
269            # Catch struct content
270            key = "[DEFAULT]"
271            if self.inParDefs and line.count(key)>0:
272                # Found a new parameter
273                try:
274                    index = line.index(key)
275                    toks = line[index:].split("=")
276                    toks2 = toks[2].split()
277                    val = float(toks2[0])
278                    self.params[toks[1]] = val
279                    #self.pythonClass = toks[1].lstrip().rstrip()
280                    units = ""
281                    if len(toks2) >= 2:
282                        units = toks2[1]
283                    self.default_list += "         %-15s = %s %s\n" % \
284                        (toks[1], val, units)
285                   
286                    # Check for min and max
287                    min = "None"
288                    max = "None"
289                    if len(toks2) == 4:
290                        min = toks2[2]
291                        max = toks2[3]
292                   
293                    self.details += "        self.details['%s'] = ['%s', %s, %s]\n" % \
294                        (toks[1].lstrip().rstrip(), units.lstrip().rstrip(), min, max)
295                except:
296                    raise ValueError, "Could not parse input file %s \n  %s" % \
297                        (self.file, sys.exc_value)
298               
299               
300            # Catch need for numerical calculations
301            key = "CalcParameters calcPars"
302            if line.count(key)>0:
303                self.modelCalcFlag = True
304               
305            # Catch list of dispersed parameters
306            key = "[DISP_PARAMS]"
307            if line.count(key)>0:
308                try:
309                    index = line.index(key)
310                    toks = line[index:].split("=")
311                    list_str = toks[1].lstrip().rstrip()
312                    self.disp_params = list_str.split(',')
313                except:
314                    raise ValueError, "Could not parse file %s" % self.file
315
316    def write_c_wrapper(self):
317        """ Writes the C file to create the python extension class
318            The file is written in C[PYTHONCLASS].c
319        """
320        file_path = os.path.join(self.c_wrapper_dir, "C"+self.pythonClass+'.cpp')
321        file = open(file_path, 'w')
322       
323        template = open(os.path.join(os.path.dirname(__file__), "classTemplate.txt"), 'r')
324       
325        tmp_buf = template.read()
326        #tmp_lines = string.split(tmp_buf,'\n')
327        tmp_lines = tmp_buf.split('\n')
328       
329        for tmp_line in tmp_lines:
330           
331            # Catch class name
332            newline = self.replaceToken(tmp_line, 
333                                        "[PYTHONCLASS]", 'C'+self.pythonClass)
334            #Catch model description
335            #newline = self.replaceToken(tmp_line,
336            #                            "[DESCRIPTION]", self.description)
337            # Catch C model name
338            newline = self.replaceToken(newline, 
339                                        "[CMODEL]", self.pythonClass)
340           
341            # Catch class name
342            newline = self.replaceToken(newline, 
343                                        "[MODELSTRUCT]", self.structName)
344
345            # Sort model initialization based on multifunc
346            if(self.is_multifunc):
347                line = "int level = 1;\nPyArg_ParseTuple(args,\"i\",&level);\n"
348                line += "self->model = new " + self.pythonClass + "(level);"
349            else:
350                line = "self->model = new " + self.pythonClass + "();"
351   
352            newline = self.replaceToken(newline,"[INITIALIZE_MODEL]",
353                                            line)
354           
355            # Dictionary initialization
356            param_str = "// Initialize parameter dictionary\n"           
357            for par in self.params:
358                param_str += "        PyDict_SetItemString(self->params,\"%s\",Py_BuildValue(\"d\",%10.12f));\n" % \
359                    (par, self.params[par])
360
361            if len(self.disp_params)>0:
362                param_str += "        // Initialize dispersion / averaging parameter dict\n"
363                param_str += "        DispersionVisitor* visitor = new DispersionVisitor();\n"
364                param_str += "        PyObject * disp_dict;\n"
365                for par in self.disp_params:
366                    par = par.strip()
367                    param_str += "        disp_dict = PyDict_New();\n"
368                    param_str += "        self->model->%s.dispersion->accept_as_source(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
369                    param_str += "        PyDict_SetItemString(self->dispersion, \"%s\", disp_dict);\n" % par
370               
371            # Initialize dispersion object dictionnary
372            param_str += "\n"
373           
374               
375            newline = self.replaceToken(newline,
376                                        "[INITDICTIONARY]", param_str)
377           
378            # Read dictionary
379            param_str = "    // Reader parameter dictionary\n"
380            for par in self.params:
381                param_str += "    self->model->%s = PyFloat_AsDouble( PyDict_GetItemString(self->params, \"%s\") );\n" % \
382                    (par, par)
383                   
384            if len(self.disp_params)>0:
385                param_str += "    // Read in dispersion parameters\n"
386                param_str += "    PyObject* disp_dict;\n"
387                param_str += "    DispersionVisitor* visitor = new DispersionVisitor();\n"
388                for par in self.disp_params:
389                    par = par.strip()
390                    param_str += "    disp_dict = PyDict_GetItemString(self->dispersion, \"%s\");\n" % par
391                    param_str += "    self->model->%s.dispersion->accept_as_destination(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
392               
393            newline = self.replaceToken(newline, "[READDICTIONARY]", param_str)
394               
395            # Name of .c file
396            #toks = string.split(self.file,'.')
397            basename = os.path.basename(self.file)
398            toks = basename.split('.')
399            newline = self.replaceToken(newline, "[C_FILENAME]", toks[0])
400           
401            # Include file
402            basename = os.path.basename(self.file)
403            newline = self.replaceToken(newline, 
404                                        "[INCLUDE_FILE]", self.file) 
405            if self.foundCPP:
406                newline = self.replaceToken(newline, 
407                                            "[C_INCLUDE_FILE]", "") 
408                newline = self.replaceToken(newline, 
409                                            "[CPP_INCLUDE_FILE]", "#include \"%s\"" % basename) 
410            else: 
411                newline = self.replaceToken(newline, 
412                                            "[C_INCLUDE_FILE]", "#include \"%s\"" % basename)   
413                newline = self.replaceToken(newline, 
414                                            "[CPP_INCLUDE_FILE]", "#include \"models.hh\"") 
415               
416            # Numerical calcs dealloc
417            dealloc_str = "\n"
418            if self.modelCalcFlag:
419                dealloc_str = "    modelcalculations_dealloc(&(self->model_pars.calcPars));\n"
420            newline = self.replaceToken(newline, 
421                                        "[NUMERICAL_DEALLOC]", dealloc_str)     
422               
423            # Numerical calcs init
424            init_str = "\n"
425            if self.modelCalcFlag:
426                init_str = "        modelcalculations_init(&(self->model_pars.calcPars));\n"
427            newline = self.replaceToken(newline, 
428                                        "[NUMERICAL_INIT]", init_str)     
429               
430            # Numerical calcs reset
431            reset_str = "\n"
432            if self.modelCalcFlag:
433                reset_str = "modelcalculations_reset(&(self->model_pars.calcPars));\n"
434            newline = self.replaceToken(newline, 
435                                        "[NUMERICAL_RESET]", reset_str)     
436               
437            # Setting dispsertion weights
438            set_weights = "    // Ugliness necessary to go from python to C\n"
439            set_weights = "    // TODO: refactor this\n"
440            for par in self.disp_params:
441                par = par.strip()
442                set_weights += "    if (!strcmp(par_name, \"%s\")) {\n" % par
443                set_weights += "        self->model->%s.dispersion = dispersion;\n" % par
444                set_weights += "    } else"
445            newline = self.replaceToken(newline, 
446                                        "[SET_DISPERSION]", set_weights)     
447           
448            # Write new line to the wrapper .c file
449            file.write(newline+'\n')
450           
451           
452        file.close()
453       
454    def write_python_wrapper(self):
455        """ Writes the python file to create the python extension class
456            The file is written in ../[PYTHONCLASS].py
457        """
458        file_path = os.path.join(self.output_dir, self.pythonClass+'.py')
459        file = open(file_path, 'w')
460        template = open(os.path.join(os.path.dirname(__file__), "modelTemplate.txt"), 'r')
461       
462        tmp_buf = template.read()
463        tmp_lines = tmp_buf.split('\n')
464       
465        for tmp_line in tmp_lines:
466           
467            # Catch class name
468            newline = self.replaceToken(tmp_line, 
469                                        "[CPYTHONCLASS]", 'C'+self.pythonClass)
470           
471            # Catch class name
472            newline = self.replaceToken(newline, 
473                                        "[PYTHONCLASS]", self.pythonClass)
474           
475            # Include file
476            newline = self.replaceToken(newline, 
477                                        "[INCLUDE_FILE]", self.file)   
478                   
479            # Include file
480            newline = self.replaceToken(newline, 
481                                        "[DEFAULT_LIST]", self.default_list)
482            # model description
483            newline = self.replaceToken(newline, 
484                                        "[DESCRIPTION]", self.description)
485            # Parameter details
486            newline = self.replaceToken(newline, 
487                                        "[PAR_DETAILS]", self.details)
488           
489            # Call base constructor
490            if self.is_multifunc:
491                newline = self.replaceToken(newline,"[CALL_CPYTHON_INIT]",
492                                            'C' + self.pythonClass + ".__init__(self,multfactor)\n\tself.is_multifunc = True")
493                newline = self.replaceToken(newline,"[MULTIPLICITY_INFO]",self.multiplicity_info)
494            else:
495                newline = self.replaceToken(newline,"[CALL_CPYTHON_INIT]",
496                                            'C' + self.pythonClass + ".__init__(self)\n\tself.is_multifunc = False")
497                newline = self.replaceToken(newline,"[MULTIPLICITY_INFO]","None")
498
499           
500            # fixed list  details
501            fixed_str = str(self.fixed)
502            fixed_str = fixed_str.replace(', ', ',\n                      ')
503            newline = self.replaceToken(newline, "[FIXED]",fixed_str)
504           
505            # non-fittable list details
506            pars_str = str(self.non_fittable)
507            pars_str = pars_str.replace(', ', 
508                                        ',\n                             ')
509            newline = self.replaceToken(newline, 
510                                        "[NON_FITTABLE_PARAMS]",
511                                        pars_str)
512           
513            ## parameters with orientation
514            oriented_str = str(self.orientation_params)
515            formatted_endl = ',\n                                   '
516            oriented_str = oriented_str.replace(', ', formatted_endl)
517            newline = self.replaceToken(newline, 
518                               "[ORIENTATION_PARAMS]", oriented_str)
519           
520            if self.category:
521                newline = self.replaceToken(newline,"[CATEGORY]"
522                                            ,'"' + self.category + '"')
523            else:
524                newline = self.replaceToken(newline,"[CATEGORY]",
525                                            "None")
526           
527
528
529            # Write new line to the wrapper .c file
530            file.write(newline+'\n')
531               
532        file.close()
533       
534       
535    def replaceToken(self, line, key, value): #pylint: disable-msg=R0201
536        """ Replace a token in the template file
537            @param line: line of text to inspect
538            @param key: token to look for
539            @param value: string value to replace the token with
540            @return: new string value
541        """
542        lenkey = len(key)
543        newline = line
544       
545        while newline.count(key)>0:
546            index = newline.index(key)
547            newline = newline[:index]+value+newline[index+lenkey:]
548       
549        return newline
550       
551    def getModelName(self):
552        return self.pythonClass
553       
554
555
556# main
557if __name__ == '__main__':
558    if len(sys.argv)>1:
559        print "Will look for file %s" % sys.argv[1]
560        app = WrapperGenerator(sys.argv[1])
561    else:
562        app = WrapperGenerator("test.h")
563    app.read()
564    app.write_c_wrapper()
565    app.write_python_wrapper()
566    print app
567   
568# End of file       
Note: See TracBrowser for help on using the repository browser.