source: sasview/sansmodels/src/python_wrapper/WrapperGenerator.py @ 5ebdf14

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 5ebdf14 was 318b5bbb, checked in by Jae Cho <jhjcho@…>, 12 years ago

Added polarization and magnetic stuffs

  • Property mode set to 100644
File size: 23.6 KB
RevLine 
[af03ddd]1#!/usr/bin/env python
2""" WrapperGenerator class to generate model code automatically.
3"""
4
[95986b5]5import os, sys,re
[25a608f5]6import lineparser
[af03ddd]7
8class WrapperGenerator:
9    """ Python wrapper generator for C models
10   
11        The developer must provide a header file describing
12        the new model.
13       
14        To provide the name of the Python class to be
15        generated, the .h file must contain the following
16        string in the comments:
17       
18        // [PYTHONCLASS] = my_model
19       
20        where my_model must be replaced by the name of the
21        class that you want to import from sans.models.
22        (example: [PYTHONCLASS] = MyModel
23          will create a class MyModel in sans.models.MyModel.
24          It will also create a class CMyModel in
25          sans_extension.c_models.)
26         
27        Also in comments, each parameter of the params
28        dictionary must be declared with a default value
29        in the following way:
30       
31        // [DEFAULT]=param_name=default_value
32       
33        (example:
34            //  [DEFAULT]=radius=20.0
35        )
36         
37        See cylinder.h for an example.
38       
39       
40        A .c file corresponding to the .h file should also
41        be provided (example: my_model.h, my_model.c).
42   
43        The .h file should define two function definitions. For example,
44        cylinder.h defines the following:
45       
[8389beb]46        /// 1D scattering function
47        double cylinder_analytical_1D(CylinderParameters *pars, double q);
[af03ddd]48           
[8389beb]49        /// 2D scattering function
50        double cylinder_analytical_2D(CylinderParameters *pars, double q, double phi);
[af03ddd]51           
52        The .c file implements those functions.
53       
54        @author: Mathieu Doucet / UTK
55        @contact: mathieu.doucet@nist.gov
56    """
57   
[642a025]58    def __init__(self, filename, output_dir='.', c_wrapper_dir='.'):
[af03ddd]59        """ Initialization """
60       
61        ## Name of .h file to generate wrapper from
62        self.file = filename
63       
64        # Info read from file
65       
66        ## Name of python class to write
67        self.pythonClass = None
68        ## Parser in struct section
69        self.inStruct = False
[a1d1b90]70        self.foundCPP = False
[d62f422]71        self.inParDefs = False
[af03ddd]72        ## Name of struct for the c object
73        self.structName = None
74        ## Dictionary of parameters
75        self.params = {}
76        ## ModelCalculation module flag
77        self.modelCalcFlag = False
78        ## List of default parameters (text)
79        self.default_list = ""
80        ## Dictionary of units
81        self.details = ""
82        ## List of dispersed parameters
83        self.disp_params = []
[4e2f6ef8]84        #model description
85        self.description=''
[836fe6e]86        # paramaters for fittable
87        self.fixed= []
[35aface]88        # paramaters for non-fittable
89        self.non_fittable= []
[25a608f5]90        ## parameters with orientation
91        self.orientation_params =[]
[318b5bbb]92        ## parameter with magnetism
93        self.magentic_params = []
[fa6db8b]94        # Model category
95        self.category = None
[5da3cc5]96        # Whether model belongs to multifunc
97        self.is_multifunc = False
[2d1b700]98        ## output directory for wrappers
99        self.output_dir = output_dir
[642a025]100        self.c_wrapper_dir = c_wrapper_dir
[0a9686d]101
[25a608f5]102       
[af03ddd]103       
104    def __repr__(self):
105        """ Simple output for printing """
106       
[25a608f5]107        rep  = "\n Python class: %s\n\n" % self.pythonClass
108        rep += "  struc name: %s\n\n" % self.structName
109        rep += "  params:     %s\n\n" % self.params
110        rep += "  description:    %s\n\n" % self.description
111        rep += "  Fittable parameters:     %s\n\n"% self.fixed
[35aface]112        rep += "  Non-Fittable parameters:     %s\n\n"% self.non_fittable
[25a608f5]113        rep += "  Orientation parameters:  %s\n\n"% self.orientation_params
[318b5bbb]114        rep += "  Magnetic parameters:  %s\n\n"% self.magnetic_params
[af03ddd]115        return rep
116       
117    def read(self):
118        """ Reads in the .h file to catch parameters of the wrapper """
119       
120        # Check if the file is there
121        if not os.path.isfile(self.file):
122            raise ValueError, "File %s is not a regular file" % self.file
123       
124        # Read file
125        f = open(self.file,'r')
126        buf = f.read()
127       
128        self.default_list = "List of default parameters:\n"
129        #lines = string.split(buf,'\n')
130        lines = buf.split('\n')
131        self.details  = "## Parameter details [units, min, max]\n"
132        self.details += "        self.details = {}\n"
[25a608f5]133       
[836fe6e]134        #open item in this case Fixed
135        text='text'
136        key2="<%s>"%text.lower()
137        # close an item in this case fixed
[da3dae3]138        text='TexT'
[836fe6e]139        key3="</%s>"%text.lower()
140       
[25a608f5]141        ## Catch fixed parameters
142        key = "[FIXED]"
143        try:
[79ba1fc]144            self.fixed= lineparser.readhelper(lines, key, 
145                                              key2, key3, file=self.file)
[25a608f5]146        except:
147           raise   
[35aface]148        ## Catch non-fittable parameters parameters
149        key = "[NON_FITTABLE_PARAMS]"
150        try:
[79ba1fc]151            self.non_fittable= lineparser.readhelper(lines, key, key2,
152                                                     key3, file=self.file)
[35aface]153        except:
154           raise   
155
[25a608f5]156        ## Catch parameters with orientation
157        key = "[ORIENTATION_PARAMS]"   
158        try:
[79ba1fc]159            self.orientation_params = lineparser.readhelper(lines, key, 
160                                                    key2, key3, file=self.file)
[25a608f5]161        except:
162           raise 
[318b5bbb]163       
164        ## Catch parameters with orientation
165        key = "[MAGNETIC_PARAMS]"   
166        try:
167            self.magnetic_params = lineparser.readhelper( lines,key, 
168                                                    key2,key3, file= self.file)
169        except:
170           raise 
171       
[25a608f5]172        ## Catch Description
[96672c0]173        key = "[DESCRIPTION]"
[25a608f5]174       
175        find_description = False
[96672c0]176        temp=""
177        for line in lines:
178            if line.count(key)>0 :
[9316609]179               
[96672c0]180                try:
[25a608f5]181                    find_description= True
[96672c0]182                    index = line.index(key)
183                    toks = line[index:].split("=",1 )
184                    temp=toks[1].lstrip().rstrip()
185                    text='text'
[9316609]186                    key2="<%s>"%text.lower()
[96672c0]187                    if re.match(key2,temp)!=None:
[25a608f5]188   
[9316609]189                        toks2=temp.split(key2,1)
190                        self.description=toks2[1]
191                        text='text'
192                        key2="</%s>"%text.lower()
193                        if re.search(key2,toks2[1])!=None:
194                            temp=toks2[1].split(key2,1)
195                            self.description=temp[0]
196                            break
[25a608f5]197                     
[96672c0]198                    else:
199                        self.description=temp
[9316609]200                        break
[96672c0]201                except:
202                     raise ValueError, "Could not parse file %s" % self.file
[25a608f5]203            elif find_description:
[9316609]204                text='text'
205                key2="</%s>"%text.lower()
206                if re.search(key2,line)!=None:
207                    tok=line.split(key2,1)
208                    temp=tok[0].split("//",1)
209                    self.description+=tok[1].lstrip().rstrip()
210                    break
211                else:
212                    if re.search("//",line)!=None:
213                        temp=line.split("//",1)
214                        self.description+='\n\t\t'+temp[1].lstrip().rstrip()
215                       
216                    else:
217                        self.description+='\n\t\t'+line.lstrip().rstrip()
218                   
219               
[0a9686d]220     
[af03ddd]221        for line in lines:
222           
223            # Catch class name
224            key = "[PYTHONCLASS]"
225            if line.count(key)>0:
226                try:
227                    index = line.index(key)
228                    toks = line[index:].split("=" )
229                    self.pythonClass = toks[1].lstrip().rstrip()
[0a9686d]230
[af03ddd]231                except:
232                    raise ValueError, "Could not parse file %s" % self.file
233               
[fa6db8b]234            key = "[CATEGORY]"
235            if line.count(key)>0:
236                try:
237                    index = line.index(key)
238                    toks = line[index:].split("=")
239                    self.category = toks[1].lstrip().rstrip()
240
241                except:
242                    raise ValueError, "Could not parse file %s" % self.file
243
[5da3cc5]244            # is_multifunc
245            key = "[MULTIPLICITY_INFO]"
246            if line.count(key) > 0:
247                self.is_multifunc = True
248                try:
249                    index = line.index(key)
250                    toks = line[index:].split("=")
251                    self.multiplicity_info = toks[1].lstrip().rstrip()
252                except:
253                    raise ValueError, "Could not parse file %s" % self.file
[fa6db8b]254
[af03ddd]255            # Catch struct name
[a1d1b90]256            # C++ class definition
257            if line.count("class")>0:
258                # We are entering a class definition
[d62f422]259                self.inParDefs = True
[a1d1b90]260                self.foundCPP = True
261               
262            # Old-Style C struct definition
[af03ddd]263            if line.count("typedef struct")>0:
264                # We are entering a struct block
[d62f422]265                self.inParDefs = True
[af03ddd]266                self.inStruct = True
267           
[d62f422]268            if self.inParDefs and line.count("}")>0:
[af03ddd]269                # We are exiting a struct block
[d62f422]270                self.inParDefs = False
[af03ddd]271               
[d62f422]272                if self.inStruct:
273                    self.inStruct = False
274                    # Catch the name of the struct
275                    index = line.index("}")
276                    toks = line[index+1:].split(";")
277                    # Catch pointer definition
278                    toks2 = toks[0].split(',')
279                    self.structName = toks2[0].lstrip().rstrip()
[96672c0]280           
[af03ddd]281            # Catch struct content
282            key = "[DEFAULT]"
[d62f422]283            if self.inParDefs and line.count(key)>0:
[af03ddd]284                # Found a new parameter
285                try:
286                    index = line.index(key)
287                    toks = line[index:].split("=")
288                    toks2 = toks[2].split()
289                    val = float(toks2[0])
290                    self.params[toks[1]] = val
291                    #self.pythonClass = toks[1].lstrip().rstrip()
292                    units = ""
293                    if len(toks2) >= 2:
294                        units = toks2[1]
295                    self.default_list += "         %-15s = %s %s\n" % \
296                        (toks[1], val, units)
297                   
298                    # Check for min and max
299                    min = "None"
300                    max = "None"
301                    if len(toks2) == 4:
302                        min = toks2[2]
303                        max = toks2[3]
304                   
305                    self.details += "        self.details['%s'] = ['%s', %s, %s]\n" % \
306                        (toks[1].lstrip().rstrip(), units.lstrip().rstrip(), min, max)
307                except:
308                    raise ValueError, "Could not parse input file %s \n  %s" % \
309                        (self.file, sys.exc_value)
310               
311               
312            # Catch need for numerical calculations
313            key = "CalcParameters calcPars"
314            if line.count(key)>0:
315                self.modelCalcFlag = True
316               
317            # Catch list of dispersed parameters
318            key = "[DISP_PARAMS]"
319            if line.count(key)>0:
320                try:
321                    index = line.index(key)
322                    toks = line[index:].split("=")
323                    list_str = toks[1].lstrip().rstrip()
324                    self.disp_params = list_str.split(',')
325                except:
326                    raise ValueError, "Could not parse file %s" % self.file
[0a9686d]327
[af03ddd]328    def write_c_wrapper(self):
329        """ Writes the C file to create the python extension class
330            The file is written in C[PYTHONCLASS].c
331        """
[8389beb]332        file_path = os.path.join(self.c_wrapper_dir, 
333                                 "C"+self.pythonClass+'.cpp')
[2d1b700]334        file = open(file_path, 'w')
[af03ddd]335       
[8389beb]336        template = open(os.path.join(os.path.dirname(__file__), 
337                                     "classTemplate.txt"), 'r')
[af03ddd]338       
339        tmp_buf = template.read()
340        #tmp_lines = string.split(tmp_buf,'\n')
341        tmp_lines = tmp_buf.split('\n')
342       
343        for tmp_line in tmp_lines:
344           
345            # Catch class name
346            newline = self.replaceToken(tmp_line, 
347                                        "[PYTHONCLASS]", 'C'+self.pythonClass)
[9316609]348            #Catch model description
[95986b5]349            #newline = self.replaceToken(tmp_line,
350            #                            "[DESCRIPTION]", self.description)
[af03ddd]351            # Catch C model name
352            newline = self.replaceToken(newline, 
353                                        "[CMODEL]", self.pythonClass)
354           
355            # Catch class name
356            newline = self.replaceToken(newline, 
357                                        "[MODELSTRUCT]", self.structName)
[5da3cc5]358
359            # Sort model initialization based on multifunc
360            if(self.is_multifunc):
361                line = "int level = 1;\nPyArg_ParseTuple(args,\"i\",&level);\n"
362                line += "self->model = new " + self.pythonClass + "(level);"
363            else:
364                line = "self->model = new " + self.pythonClass + "();"
365   
366            newline = self.replaceToken(newline,"[INITIALIZE_MODEL]",
367                                            line)
[af03ddd]368           
369            # Dictionary initialization
370            param_str = "// Initialize parameter dictionary\n"           
371            for par in self.params:
[35aface]372                param_str += "        PyDict_SetItemString(self->params,\"%s\",Py_BuildValue(\"d\",%10.12f));\n" % \
[0f5bc9f]373                    (par, self.params[par])
[af03ddd]374
[1b758b3]375            if len(self.disp_params)>0:
376                param_str += "        // Initialize dispersion / averaging parameter dict\n"
377                param_str += "        DispersionVisitor* visitor = new DispersionVisitor();\n"
378                param_str += "        PyObject * disp_dict;\n"
379                for par in self.disp_params:
380                    par = par.strip()
381                    param_str += "        disp_dict = PyDict_New();\n"
382                    param_str += "        self->model->%s.dispersion->accept_as_source(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
383                    param_str += "        PyDict_SetItemString(self->dispersion, \"%s\", disp_dict);\n" % par
[af03ddd]384               
385            # Initialize dispersion object dictionnary
386            param_str += "\n"
387           
388               
389            newline = self.replaceToken(newline,
390                                        "[INITDICTIONARY]", param_str)
391           
392            # Read dictionary
393            param_str = "    // Reader parameter dictionary\n"
394            for par in self.params:
395                param_str += "    self->model->%s = PyFloat_AsDouble( PyDict_GetItemString(self->params, \"%s\") );\n" % \
396                    (par, par)
397                   
[1b758b3]398            if len(self.disp_params)>0:
399                param_str += "    // Read in dispersion parameters\n"
400                param_str += "    PyObject* disp_dict;\n"
401                param_str += "    DispersionVisitor* visitor = new DispersionVisitor();\n"
402                for par in self.disp_params:
403                    par = par.strip()
404                    param_str += "    disp_dict = PyDict_GetItemString(self->dispersion, \"%s\");\n" % par
405                    param_str += "    self->model->%s.dispersion->accept_as_destination(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
[af03ddd]406               
407            newline = self.replaceToken(newline, "[READDICTIONARY]", param_str)
408               
409            # Name of .c file
410            #toks = string.split(self.file,'.')
411            basename = os.path.basename(self.file)
412            toks = basename.split('.')
413            newline = self.replaceToken(newline, "[C_FILENAME]", toks[0])
414           
415            # Include file
416            basename = os.path.basename(self.file)
417            newline = self.replaceToken(newline, 
[a1d1b90]418                                        "[INCLUDE_FILE]", self.file) 
419            if self.foundCPP:
420                newline = self.replaceToken(newline, 
[8389beb]421                            "[C_INCLUDE_FILE]", "") 
[a1d1b90]422                newline = self.replaceToken(newline, 
[8389beb]423                            "[CPP_INCLUDE_FILE]", "#include \"%s\"" % basename) 
[a1d1b90]424            else: 
425                newline = self.replaceToken(newline, 
[8389beb]426                            "[C_INCLUDE_FILE]", "#include \"%s\"" % basename)   
[a1d1b90]427                newline = self.replaceToken(newline, 
[8389beb]428                            "[CPP_INCLUDE_FILE]", "#include \"models.hh\"") 
[af03ddd]429               
430            # Numerical calcs dealloc
431            dealloc_str = "\n"
432            if self.modelCalcFlag:
433                dealloc_str = "    modelcalculations_dealloc(&(self->model_pars.calcPars));\n"
434            newline = self.replaceToken(newline, 
435                                        "[NUMERICAL_DEALLOC]", dealloc_str)     
436               
437            # Numerical calcs init
438            init_str = "\n"
439            if self.modelCalcFlag:
440                init_str = "        modelcalculations_init(&(self->model_pars.calcPars));\n"
441            newline = self.replaceToken(newline, 
442                                        "[NUMERICAL_INIT]", init_str)     
443               
444            # Numerical calcs reset
445            reset_str = "\n"
446            if self.modelCalcFlag:
447                reset_str = "modelcalculations_reset(&(self->model_pars.calcPars));\n"
448            newline = self.replaceToken(newline, 
449                                        "[NUMERICAL_RESET]", reset_str)     
450               
451            # Setting dispsertion weights
452            set_weights = "    // Ugliness necessary to go from python to C\n"
453            set_weights = "    // TODO: refactor this\n"
454            for par in self.disp_params:
455                par = par.strip()
456                set_weights += "    if (!strcmp(par_name, \"%s\")) {\n" % par
457                set_weights += "        self->model->%s.dispersion = dispersion;\n" % par
458                set_weights += "    } else"
459            newline = self.replaceToken(newline, 
460                                        "[SET_DISPERSION]", set_weights)     
461           
462            # Write new line to the wrapper .c file
463            file.write(newline+'\n')
464           
465           
466        file.close()
467       
468    def write_python_wrapper(self):
469        """ Writes the python file to create the python extension class
470            The file is written in ../[PYTHONCLASS].py
471        """
[642a025]472        file_path = os.path.join(self.output_dir, self.pythonClass+'.py')
[2d1b700]473        file = open(file_path, 'w')
[8389beb]474        template = open(os.path.join(os.path.dirname(__file__), 
475                                     "modelTemplate.txt"), 'r')
[af03ddd]476       
477        tmp_buf = template.read()
478        tmp_lines = tmp_buf.split('\n')
479       
480        for tmp_line in tmp_lines:
481           
482            # Catch class name
483            newline = self.replaceToken(tmp_line, 
[318b5bbb]484                                        "[CPYTHONCLASS]", 
485                                        'C' + self.pythonClass)
[af03ddd]486           
487            # Catch class name
488            newline = self.replaceToken(newline, 
489                                        "[PYTHONCLASS]", self.pythonClass)
490           
491            # Include file
492            newline = self.replaceToken(newline, 
493                                        "[INCLUDE_FILE]", self.file)   
494                   
495            # Include file
496            newline = self.replaceToken(newline, 
497                                        "[DEFAULT_LIST]", self.default_list)
[95986b5]498            # model description
[af03ddd]499            newline = self.replaceToken(newline, 
[95986b5]500                                        "[DESCRIPTION]", self.description)
[4e2f6ef8]501            # Parameter details
502            newline = self.replaceToken(newline, 
[95986b5]503                                        "[PAR_DETAILS]", self.details)
[5da3cc5]504           
505            # Call base constructor
506            if self.is_multifunc:
507                newline = self.replaceToken(newline,"[CALL_CPYTHON_INIT]",
[8389beb]508                    'C' + self.pythonClass + \
509                    ".__init__(self,multfactor)\n\tself.is_multifunc = True")
510                newline = self.replaceToken(newline,"[MULTIPLICITY_INFO]", 
511                                            self.multiplicity_info)
[5da3cc5]512            else:
513                newline = self.replaceToken(newline,"[CALL_CPYTHON_INIT]",
[8389beb]514                    'C' + self.pythonClass + \
515                    ".__init__(self)\n        self.is_multifunc = False")
516                newline = self.replaceToken(newline, 
[318b5bbb]517                                            "[MULTIPLICITY_INFO]", "None")
[5da3cc5]518
[95986b5]519           
[836fe6e]520            # fixed list  details
[79ba1fc]521            fixed_str = str(self.fixed)
[0ea247e]522            fixed_str = fixed_str.replace(', ', ',\n                      ')
[318b5bbb]523            newline = self.replaceToken(newline, "[FIXED]", fixed_str)
[79ba1fc]524           
525            # non-fittable list details
[7343319]526            pars_str = str(self.non_fittable)
[0ea247e]527            pars_str = pars_str.replace(', ', 
528                                        ',\n                             ')
[836fe6e]529            newline = self.replaceToken(newline, 
[318b5bbb]530                                        "[NON_FITTABLE_PARAMS]", pars_str)
[79ba1fc]531           
[25a608f5]532            ## parameters with orientation
[79ba1fc]533            oriented_str = str(self.orientation_params)
[0ea247e]534            formatted_endl = ',\n                                   '
[5c46f43]535            oriented_str = oriented_str.replace(', ', formatted_endl)
[25a608f5]536            newline = self.replaceToken(newline, 
[79ba1fc]537                               "[ORIENTATION_PARAMS]", oriented_str)
[318b5bbb]538           ## parameters with magnetism
539            newline = self.replaceToken(newline, 
540                               "[MAGNETIC_PARAMS]", str(self.magnetic_params))
541
[fa6db8b]542            if self.category:
[318b5bbb]543                newline = self.replaceToken(newline, "[CATEGORY]", 
544                                            '"' + self.category + '"')
[fa6db8b]545            else:
[318b5bbb]546                newline = self.replaceToken(newline, "[CATEGORY]",
[fa6db8b]547                                            "None")
548           
549
550
[af03ddd]551            # Write new line to the wrapper .c file
552            file.write(newline+'\n')
553               
554        file.close()
555       
556       
557    def replaceToken(self, line, key, value): #pylint: disable-msg=R0201
558        """ Replace a token in the template file
559            @param line: line of text to inspect
560            @param key: token to look for
561            @param value: string value to replace the token with
562            @return: new string value
563        """
564        lenkey = len(key)
565        newline = line
[836fe6e]566       
[af03ddd]567        while newline.count(key)>0:
568            index = newline.index(key)
569            newline = newline[:index]+value+newline[index+lenkey:]
[836fe6e]570       
[af03ddd]571        return newline
572       
[0a9686d]573    def getModelName(self):
574        return self.pythonClass
[af03ddd]575       
[0a9686d]576
577
[af03ddd]578# main
579if __name__ == '__main__':
580    if len(sys.argv)>1:
581        print "Will look for file %s" % sys.argv[1]
582        app = WrapperGenerator(sys.argv[1])
583    else:
584        app = WrapperGenerator("test.h")
585    app.read()
586    app.write_c_wrapper()
587    app.write_python_wrapper()
588    print app
589   
[0a9686d]590# End of file       
Note: See TracBrowser for help on using the repository browser.