source: sasview/src/sas/models/c_extension/python_wrapper/WrapperGenerator.py @ 6bd3a8d1

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 6bd3a8d1 was fd5ac0d, checked in by krzywon, 10 years ago

I have completed the removal of all SANS references.
I will build, run, and run all unit tests before pushing.

  • Property mode set to 100644
File size: 23.8 KB
Line 
1#!/usr/bin/env python
2""" WrapperGenerator class to generate model code automatically.
3"""
4
5import os, sys,re
6import lineparser
7
8class WrapperGenerator:
9    """ Python wrapper generator for C models
10   
11        The developer must provide a header file describing
12        the new model.
13       
14        To provide the name of the Python class to be
15        generated, the .h file must contain the following
16        string in the comments:
17       
18        // [PYTHONCLASS] = my_model
19       
20        where my_model must be replaced by the name of the
21        class that you want to import from sas.models.
22        (example: [PYTHONCLASS] = MyModel
23          will create a class MyModel in sas.models.MyModel.
24          It will also create a class CMyModel in
25          sas_extension.c_models.)
26         
27        Also in comments, each parameter of the params
28        dictionary must be declared with a default value
29        in the following way:
30       
31        // [DEFAULT]=param_name=default_value
32       
33        (example:
34            //  [DEFAULT]=radius=20.0
35        )
36         
37        See cylinder.h for an example.
38       
39       
40        A .c file corresponding to the .h file should also
41        be provided (example: my_model.h, my_model.c).
42   
43        The .h file should define two function definitions. For example,
44        cylinder.h defines the following:
45       
46        /// 1D scattering function
47        double cylinder_analytical_1D(CylinderParameters *pars, double q);
48           
49        /// 2D scattering function
50        double cylinder_analytical_2D(CylinderParameters *pars, double q, double phi);
51           
52        The .c file implements those functions.
53       
54        @author: Mathieu Doucet / UTK
55        @contact: mathieu.doucet@nist.gov
56    """
57   
58    def __init__(self, filename, output_dir='.', c_wrapper_dir='.'):
59        """ Initialization """
60       
61        ## Name of .h file to generate wrapper from
62        self.file = filename
63       
64        # Info read from file
65       
66        ## Name of python class to write
67        self.pythonClass = None
68        ## Parser in struct section
69        self.inStruct = False
70        self.foundCPP = False
71        self.inParDefs = False
72        ## Name of struct for the c object
73        self.structName = None
74        ## Dictionary of parameters
75        self.params = {}
76        ## ModelCalculation module flag
77        self.modelCalcFlag = False
78        ## List of default parameters (text)
79        self.default_list = ""
80        ## Dictionary of units
81        self.details = ""
82        ## List of dispersed parameters
83        self.disp_params = []
84        #model description
85        self.description=''
86        # paramaters for fittable
87        self.fixed= []
88        # paramaters for non-fittable
89        self.non_fittable= []
90        ## parameters with orientation
91        self.orientation_params =[]
92        ## parameter with magnetism
93        self.magentic_params = []
94        # Model category
95        self.category = None
96        # Whether model belongs to multifunc
97        self.is_multifunc = False
98        ## output directory for wrappers
99        self.output_dir = output_dir
100        self.c_wrapper_dir = c_wrapper_dir
101
102       
103       
104    def __repr__(self):
105        """ Simple output for printing """
106       
107        rep  = "\n Python class: %s\n\n" % self.pythonClass
108        rep += "  struc name: %s\n\n" % self.structName
109        rep += "  params:     %s\n\n" % self.params
110        rep += "  description:    %s\n\n" % self.description
111        rep += "  Fittable parameters:     %s\n\n"% self.fixed
112        rep += "  Non-Fittable parameters:     %s\n\n"% self.non_fittable
113        rep += "  Orientation parameters:  %s\n\n"% self.orientation_params
114        rep += "  Magnetic parameters:  %s\n\n"% self.magnetic_params
115        return rep
116       
117    def read(self):
118        """ Reads in the .h file to catch parameters of the wrapper """
119       
120        # Check if the file is there
121        if not os.path.isfile(self.file):
122            raise ValueError, "File %s is not a regular file" % self.file
123       
124        # Read file
125        f = open(self.file,'r')
126        buf = f.read()
127       
128        self.default_list = "\n    List of default parameters:\n\n"
129        #lines = string.split(buf,'\n')
130        lines = buf.split('\n')
131        self.details  = "## Parameter details [units, min, max]\n"
132        self.details += "        self.details = {}\n"
133       
134        #open item in this case Fixed
135        text='text'
136        key2="<%s>"%text.lower()
137        # close an item in this case fixed
138        text='TexT'
139        key3="</%s>"%text.lower()
140       
141        ## Catch fixed parameters
142        key = "[FIXED]"
143        try:
144            self.fixed= lineparser.readhelper(lines, key, 
145                                              key2, key3, file=self.file)
146        except:
147           raise   
148        ## Catch non-fittable parameters parameters
149        key = "[NON_FITTABLE_PARAMS]"
150        try:
151            self.non_fittable= lineparser.readhelper(lines, key, key2,
152                                                     key3, file=self.file)
153        except:
154           raise   
155
156        ## Catch parameters with orientation
157        key = "[ORIENTATION_PARAMS]"   
158        try:
159            self.orientation_params = lineparser.readhelper(lines, key, 
160                                                    key2, key3, file=self.file)
161        except:
162           raise 
163       
164        ## Catch parameters with orientation
165        key = "[MAGNETIC_PARAMS]"   
166        try:
167            self.magnetic_params = lineparser.readhelper( lines,key, 
168                                                    key2,key3, file= self.file)
169        except:
170           raise 
171       
172        ## Catch Description
173        key = "[DESCRIPTION]"
174       
175        find_description = False
176        temp=""
177        for line in lines:
178            if line.count(key)>0 :
179               
180                try:
181                    find_description= True
182                    index = line.index(key)
183                    toks = line[index:].split("=",1 )
184                    temp=toks[1].lstrip().rstrip()
185                    text='text'
186                    key2="<%s>"%text.lower()
187                    if re.match(key2,temp)!=None:
188   
189                        toks2=temp.split(key2,1)
190                        self.description=toks2[1]
191                        text='text'
192                        key2="</%s>"%text.lower()
193                        if re.search(key2,toks2[1])!=None:
194                            temp=toks2[1].split(key2,1)
195                            self.description=temp[0]
196                            break
197                     
198                    else:
199                        self.description=temp
200                        break
201                except:
202                     raise ValueError, "Could not parse file %s" % self.file
203            elif find_description:
204                text='text'
205                key2="</%s>"%text.lower()
206                if re.search(key2,line)!=None:
207                    tok=line.split(key2,1)
208                    temp=tok[0].split("//",1)
209                    self.description+=tok[1].lstrip().rstrip()
210                    break
211                else:
212                    if re.search("//",line)!=None:
213                        temp=line.split("//",1)
214                        self.description+='\n\t\t'+temp[1].lstrip().rstrip()
215                       
216                    else:
217                        self.description+='\n\t\t'+line.lstrip().rstrip()
218                   
219               
220     
221        for line in lines:
222           
223            # Catch class name
224            key = "[PYTHONCLASS]"
225            if line.count(key)>0:
226                try:
227                    index = line.index(key)
228                    toks = line[index:].split("=" )
229                    self.pythonClass = toks[1].lstrip().rstrip()
230
231                except:
232                    raise ValueError, "Could not parse file %s" % self.file
233               
234            key = "[CATEGORY]"
235            if line.count(key)>0:
236                try:
237                    index = line.index(key)
238                    toks = line[index:].split("=")
239                    self.category = toks[1].lstrip().rstrip()
240
241                except:
242                    raise ValueError, "Could not parse file %s" % self.file
243
244            # is_multifunc
245            key = "[MULTIPLICITY_INFO]"
246            if line.count(key) > 0:
247                self.is_multifunc = True
248                try:
249                    index = line.index(key)
250                    toks = line[index:].split("=")
251                    self.multiplicity_info = toks[1].lstrip().rstrip()
252                except:
253                    raise ValueError, "Could not parse file %s" % self.file
254
255            # Catch struct name
256            # C++ class definition
257            if line.count("class")>0:
258                # We are entering a class definition
259                self.inParDefs = True
260                self.foundCPP = True
261               
262            # Old-Style C struct definition
263            if line.count("typedef struct")>0:
264                # We are entering a struct block
265                self.inParDefs = True
266                self.inStruct = True
267           
268            if self.inParDefs and line.count("}")>0:
269                # We are exiting a struct block
270                self.inParDefs = False
271               
272                if self.inStruct:
273                    self.inStruct = False
274                    # Catch the name of the struct
275                    index = line.index("}")
276                    toks = line[index+1:].split(";")
277                    # Catch pointer definition
278                    toks2 = toks[0].split(',')
279                    self.structName = toks2[0].lstrip().rstrip()
280           
281            # Catch struct content
282            key = "[DEFAULT]"
283            if self.inParDefs and line.count(key)>0:
284                # Found a new parameter
285                try:
286                    index = line.index(key)
287                    toks = line[index:].split("=")
288                    toks2 = toks[2].split()
289                    val = float(toks2[0])
290                    self.params[toks[1]] = val
291                    #self.pythonClass = toks[1].lstrip().rstrip()
292                    units = ""
293                    if len(toks2) >= 2:
294                        units = toks2[1]
295                    self.default_list += "    * %-15s = %s %s\n" % \
296                        (toks[1], val, units)
297                   
298                    # Check for min and max
299                    min = "None"
300                    max = "None"
301                    if len(toks2) == 4:
302                        min = toks2[2]
303                        max = toks2[3]
304                   
305                    self.details += "        self.details['%s'] = ['%s', %s, %s]\n" % \
306                        (toks[1].lstrip().rstrip(), units.lstrip().rstrip(), min, max)
307                except:
308                    raise ValueError, "Could not parse input file %s \n  %s" % \
309                        (self.file, sys.exc_value)
310               
311               
312            # Catch need for numerical calculations
313            key = "CalcParameters calcPars"
314            if line.count(key)>0:
315                self.modelCalcFlag = True
316               
317            # Catch list of dispersed parameters
318            key = "[DISP_PARAMS]"
319            if line.count(key)>0:
320                try:
321                    index = line.index(key)
322                    toks = line[index:].split("=")
323                    list_str = toks[1].lstrip().rstrip()
324                    self.disp_params = list_str.split(',')
325                except:
326                    raise ValueError, "Could not parse file %s" % self.file
327
328    def write_c_wrapper(self):
329        """ Writes the C file to create the python extension class
330            The file is written in C[PYTHONCLASS].c
331        """
332        file_path = os.path.join(self.c_wrapper_dir, 
333                                 "C"+self.pythonClass+'.cpp')
334        file = open(file_path, 'w')
335       
336        template = open(os.path.join(os.path.dirname(__file__), 
337                                     "classTemplate.txt"), 'r')
338       
339        tmp_buf = template.read()
340        #tmp_lines = string.split(tmp_buf,'\n')
341        tmp_lines = tmp_buf.split('\n')
342       
343        for tmp_line in tmp_lines:
344           
345            # Catch class name
346            newline = self.replaceToken(tmp_line, 
347                                        "[PYTHONCLASS]", 'C'+self.pythonClass)
348            #Catch model description
349            #newline = self.replaceToken(tmp_line,
350            #                            "[DESCRIPTION]", self.description)
351            # Catch C model name
352            newline = self.replaceToken(newline, 
353                                        "[CMODEL]", self.pythonClass)
354           
355            # Catch class name
356            newline = self.replaceToken(newline, 
357                                        "[MODELSTRUCT]", self.structName)
358
359            # Sort model initialization based on multifunc
360            if(self.is_multifunc):
361                line = "int level = 1;\nPyArg_ParseTuple(args,\"i\",&level);\n"
362                line += "self->model = new " + self.pythonClass + "(level);"
363            else:
364                line = "self->model = new " + self.pythonClass + "();"
365   
366            newline = self.replaceToken(newline,"[INITIALIZE_MODEL]",
367                                            line)
368           
369            # Dictionary initialization
370            param_str = "// Initialize parameter dictionary\n"           
371            for par in self.params:
372                param_str += "        PyDict_SetItemString(self->params,\"%s\",Py_BuildValue(\"d\",%10.12f));\n" % \
373                    (par, self.params[par])
374
375            if len(self.disp_params)>0:
376                param_str += "        // Initialize dispersion / averaging parameter dict\n"
377                param_str += "        DispersionVisitor* visitor = new DispersionVisitor();\n"
378                param_str += "        PyObject * disp_dict;\n"
379                for par in self.disp_params:
380                    par = par.strip()
381                    if par == '':
382                        continue
383                    param_str += "        disp_dict = PyDict_New();\n"
384                    param_str += "        self->model->%s.dispersion->accept_as_source(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
385                    param_str += "        PyDict_SetItemString(self->dispersion, \"%s\", disp_dict);\n" % par
386               
387            # Initialize dispersion object dictionnary
388            param_str += "\n"
389           
390               
391            newline = self.replaceToken(newline,
392                                        "[INITDICTIONARY]", param_str)
393           
394            # Read dictionary
395            param_str = "    // Reader parameter dictionary\n"
396            for par in self.params:
397                param_str += "    self->model->%s = PyFloat_AsDouble( PyDict_GetItemString(self->params, \"%s\") );\n" % \
398                    (par, par)
399                   
400            if len(self.disp_params)>0:
401                param_str += "    // Read in dispersion parameters\n"
402                param_str += "    PyObject* disp_dict;\n"
403                param_str += "    DispersionVisitor* visitor = new DispersionVisitor();\n"
404                for par in self.disp_params:
405                    par = par.strip()
406                    if par == '':
407                        continue
408                    param_str += "    disp_dict = PyDict_GetItemString(self->dispersion, \"%s\");\n" % par
409                    param_str += "    self->model->%s.dispersion->accept_as_destination(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
410               
411            newline = self.replaceToken(newline, "[READDICTIONARY]", param_str)
412               
413            # Name of .c file
414            #toks = string.split(self.file,'.')
415            basename = os.path.basename(self.file)
416            toks = basename.split('.')
417            newline = self.replaceToken(newline, "[C_FILENAME]", toks[0])
418           
419            # Include file
420            basename = os.path.basename(self.file)
421            newline = self.replaceToken(newline, 
422                                        "[INCLUDE_FILE]", self.file) 
423            if self.foundCPP:
424                newline = self.replaceToken(newline, 
425                            "[C_INCLUDE_FILE]", "") 
426                newline = self.replaceToken(newline, 
427                            "[CPP_INCLUDE_FILE]", "#include \"%s\"" % basename) 
428            else: 
429                newline = self.replaceToken(newline, 
430                            "[C_INCLUDE_FILE]", "#include \"%s\"" % basename)   
431                newline = self.replaceToken(newline, 
432                            "[CPP_INCLUDE_FILE]", "#include \"models.hh\"") 
433               
434            # Numerical calcs dealloc
435            dealloc_str = "\n"
436            if self.modelCalcFlag:
437                dealloc_str = "    modelcalculations_dealloc(&(self->model_pars.calcPars));\n"
438            newline = self.replaceToken(newline, 
439                                        "[NUMERICAL_DEALLOC]", dealloc_str)     
440               
441            # Numerical calcs init
442            init_str = "\n"
443            if self.modelCalcFlag:
444                init_str = "        modelcalculations_init(&(self->model_pars.calcPars));\n"
445            newline = self.replaceToken(newline, 
446                                        "[NUMERICAL_INIT]", init_str)     
447               
448            # Numerical calcs reset
449            reset_str = "\n"
450            if self.modelCalcFlag:
451                reset_str = "modelcalculations_reset(&(self->model_pars.calcPars));\n"
452            newline = self.replaceToken(newline, 
453                                        "[NUMERICAL_RESET]", reset_str)     
454               
455            # Setting dispsertion weights
456            set_weights = "    // Ugliness necessary to go from python to C\n"
457            set_weights = "    // TODO: refactor this\n"
458            for par in self.disp_params:
459                par = par.strip()
460                if par == '':
461                        continue
462                set_weights += "    if (!strcmp(par_name, \"%s\")) {\n" % par
463                set_weights += "        self->model->%s.dispersion = dispersion;\n" % par
464                set_weights += "    } else"
465            newline = self.replaceToken(newline, 
466                                        "[SET_DISPERSION]", set_weights)     
467           
468            # Write new line to the wrapper .c file
469            file.write(newline+'\n')
470           
471           
472        file.close()
473       
474    def write_python_wrapper(self):
475        """ Writes the python file to create the python extension class
476            The file is written in ../[PYTHONCLASS].py
477        """
478        file_path = os.path.join(self.output_dir, self.pythonClass+'.py')
479        file = open(file_path, 'w')
480        template = open(os.path.join(os.path.dirname(__file__), 
481                                     "modelTemplate.txt"), 'r')
482       
483        tmp_buf = template.read()
484        tmp_lines = tmp_buf.split('\n')
485       
486        for tmp_line in tmp_lines:
487           
488            # Catch class name
489            newline = self.replaceToken(tmp_line, 
490                                        "[CPYTHONCLASS]", 
491                                        'C' + self.pythonClass)
492           
493            # Catch class name
494            newline = self.replaceToken(newline, 
495                                        "[PYTHONCLASS]", self.pythonClass)
496           
497            # Include file
498            newline = self.replaceToken(newline, 
499                                        "[INCLUDE_FILE]", self.file)   
500                   
501            # Include file
502            newline = self.replaceToken(newline, 
503                                        "[DEFAULT_LIST]", self.default_list)
504            # model description
505            newline = self.replaceToken(newline, 
506                                        "[DESCRIPTION]", self.description)
507            # Parameter details
508            newline = self.replaceToken(newline, 
509                                        "[PAR_DETAILS]", self.details)
510           
511            # Call base constructor
512            if self.is_multifunc:
513                newline = self.replaceToken(newline,"[CALL_CPYTHON_INIT]",
514                    'C' + self.pythonClass + \
515                    ".__init__(self,multfactor)\n\tself.is_multifunc = True")
516                newline = self.replaceToken(newline,"[MULTIPLICITY_INFO]", 
517                                            self.multiplicity_info)
518            else:
519                newline = self.replaceToken(newline,"[CALL_CPYTHON_INIT]",
520                    'C' + self.pythonClass + \
521                    ".__init__(self)\n        self.is_multifunc = False")
522                newline = self.replaceToken(newline, 
523                                            "[MULTIPLICITY_INFO]", "None")
524
525           
526            # fixed list  details
527            fixed_str = str(self.fixed)
528            fixed_str = fixed_str.replace(', ', ',\n                      ')
529            newline = self.replaceToken(newline, "[FIXED]", fixed_str)
530           
531            # non-fittable list details
532            pars_str = str(self.non_fittable)
533            pars_str = pars_str.replace(', ', 
534                                        ',\n                             ')
535            newline = self.replaceToken(newline, 
536                                        "[NON_FITTABLE_PARAMS]", pars_str)
537           
538            ## parameters with orientation
539            oriented_str = str(self.orientation_params)
540            formatted_endl = ',\n                                   '
541            oriented_str = oriented_str.replace(', ', formatted_endl)
542            newline = self.replaceToken(newline, 
543                               "[ORIENTATION_PARAMS]", oriented_str)
544           ## parameters with magnetism
545            newline = self.replaceToken(newline, 
546                               "[MAGNETIC_PARAMS]", str(self.magnetic_params))
547
548            if self.category:
549                newline = self.replaceToken(newline, "[CATEGORY]", 
550                                            '"' + self.category + '"')
551            else:
552                newline = self.replaceToken(newline, "[CATEGORY]",
553                                            "None")
554           
555
556
557            # Write new line to the wrapper .c file
558            file.write(newline+'\n')
559               
560        file.close()
561       
562       
563    def replaceToken(self, line, key, value): #pylint: disable-msg=R0201
564        """ Replace a token in the template file
565            @param line: line of text to inspect
566            @param key: token to look for
567            @param value: string value to replace the token with
568            @return: new string value
569        """
570        lenkey = len(key)
571        newline = line
572       
573        while newline.count(key)>0:
574            index = newline.index(key)
575            newline = newline[:index]+value+newline[index+lenkey:]
576       
577        return newline
578       
579    def getModelName(self):
580        return self.pythonClass
581       
582
583
584# main
585if __name__ == '__main__':
586    if len(sys.argv)>1:
587        print "Will look for file %s" % sys.argv[1]
588        app = WrapperGenerator(sys.argv[1])
589    else:
590        app = WrapperGenerator("test.h")
591    app.read()
592    app.write_c_wrapper()
593    app.write_python_wrapper()
594    print app
595   
596# End of file       
Note: See TracBrowser for help on using the repository browser.