source: sasview/src/sas/models/c_extension/python_wrapper/WrapperGenerator.py @ be0c318

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since be0c318 was 13e46abe, checked in by Doucet, Mathieu <doucetm@…>, 10 years ago

Clean up pylint

  • Property mode set to 100644
File size: 23.1 KB
Line 
1#!/usr/bin/env python
2""" WrapperGenerator class to generate model code automatically.
3"""
4
5import os, sys, re
6import lineparser
7
8class WrapperGenerator(object):
9    """ Python wrapper generator for C models
10
11        The developer must provide a header file describing
12        the new model.
13
14        To provide the name of the Python class to be
15        generated, the .h file must contain the following
16        string in the comments:
17
18        // [PYTHONCLASS] = my_model
19
20        where my_model must be replaced by the name of the
21        class that you want to import from sas.models.
22        (example: [PYTHONCLASS] = MyModel
23          will create a class MyModel in sas.models.MyModel.
24          It will also create a class CMyModel in
25          sas_extension.c_models.)
26
27        Also in comments, each parameter of the params
28        dictionary must be declared with a default value
29        in the following way:
30
31        // [DEFAULT]=param_name=default_value
32
33        (example:
34            //  [DEFAULT]=radius=20.0
35        )
36
37        See cylinder.h for an example.
38
39        A .c file corresponding to the .h file should also
40        be provided (example: my_model.h, my_model.c).
41
42        The .h file should define two function definitions. For example,
43        cylinder.h defines the following:
44
45        /// 1D scattering function
46        double cylinder_analytical_1D(CylinderParameters *pars, double q);
47
48        /// 2D scattering function
49        double cylinder_analytical_2D(CylinderParameters *pars, double q, double phi);
50
51        The .c file implements those functions.
52
53        @author: Mathieu Doucet / UTK
54        @contact: mathieu.doucet@nist.gov
55    """
56
57    def __init__(self, filename, output_dir='.', c_wrapper_dir='.'):
58        """ Initialization """
59
60        ## Name of .h file to generate wrapper from
61        self.file = filename
62
63        # Info read from file
64
65        ## Name of python class to write
66        self.pythonClass = None
67        ## Parser in struct section
68        self.inStruct = False
69        self.foundCPP = False
70        self.inParDefs = False
71        ## Name of struct for the c object
72        self.structName = None
73        ## Dictionary of parameters
74        self.params = {}
75        ## ModelCalculation module flag
76        self.modelCalcFlag = False
77        ## List of default parameters (text)
78        self.default_list = ""
79        ## Dictionary of units
80        self.details = ""
81        ## List of dispersed parameters
82        self.disp_params = []
83        #model description
84        self.description = ''
85        # paramaters for fittable
86        self.fixed = []
87        # paramaters for non-fittable
88        self.non_fittable = []
89        ## parameters with orientation
90        self.orientation_params = []
91        ## parameter with magnetism
92        self.magentic_params = []
93        # Model category
94        self.category = None
95        # Whether model belongs to multifunc
96        self.is_multifunc = False
97        ## output directory for wrappers
98        self.output_dir = output_dir
99        self.c_wrapper_dir = c_wrapper_dir
100
101    def __repr__(self):
102        """ Simple output for printing """
103
104        rep = "\n Python class: %s\n\n" % self.pythonClass
105        rep += "  struc name: %s\n\n" % self.structName
106        rep += "  params:     %s\n\n" % self.params
107        rep += "  description:    %s\n\n" % self.description
108        rep += "  Fittable parameters:     %s\n\n" % self.fixed
109        rep += "  Non-Fittable parameters:     %s\n\n" % self.non_fittable
110        rep += "  Orientation parameters:  %s\n\n" % self.orientation_params
111        rep += "  Magnetic parameters:  %s\n\n" % self.magnetic_params
112        return rep
113
114    def read(self):
115        """ Reads in the .h file to catch parameters of the wrapper """
116
117        # Check if the file is there
118        if not os.path.isfile(self.file):
119            raise ValueError, "File %s is not a regular file" % self.file
120
121        # Read file
122        f = open(self.file, 'r')
123        buf = f.read()
124
125        self.default_list = "\n    List of default parameters:\n\n"
126        #lines = string.split(buf,'\n')
127        lines = buf.split('\n')
128        self.details = "## Parameter details [units, min, max]\n"
129        self.details += "        self.details = {}\n"
130
131        #open item in this case Fixed
132        text = 'text'
133        key2 = "<%s>" % text.lower()
134        # close an item in this case fixed
135        text = 'TexT'
136        key3 = "</%s>" % text.lower()
137
138        ## Catch fixed parameters
139        key = "[FIXED]"
140        try:
141            self.fixed = lineparser.readhelper(lines, key,
142                                               key2, key3, file=self.file)
143        except:
144            raise
145        ## Catch non-fittable parameters parameters
146        key = "[NON_FITTABLE_PARAMS]"
147        try:
148            self.non_fittable = lineparser.readhelper(lines, key, key2,
149                                                      key3, file=self.file)
150        except:
151            raise
152
153        ## Catch parameters with orientation
154        key = "[ORIENTATION_PARAMS]"
155        try:
156            self.orientation_params = lineparser.readhelper(lines, key,
157                                                            key2, key3, file=self.file)
158        except:
159            raise
160
161        ## Catch parameters with orientation
162        key = "[MAGNETIC_PARAMS]"
163        try:
164            self.magnetic_params = lineparser.readhelper(lines, key,
165                                                         key2, key3, file=self.file)
166        except:
167            raise
168
169        ## Catch Description
170        key = "[DESCRIPTION]"
171
172        find_description = False
173        temp = ""
174        for line in lines:
175            if line.count(key) > 0:
176                try:
177                    find_description = True
178                    index = line.index(key)
179                    toks = line[index:].split("=", 1)
180                    temp = toks[1].lstrip().rstrip()
181                    text = 'text'
182                    key2 = "<%s>" % text.lower()
183                    if re.match(key2, temp) != None:
184
185                        toks2 = temp.split(key2, 1)
186                        self.description = toks2[1]
187                        text = 'text'
188                        key2 = "</%s>" % text.lower()
189                        if re.search(key2, toks2[1]) != None:
190                            temp = toks2[1].split(key2, 1)
191                            self.description = temp[0]
192                            break
193
194                    else:
195                        self.description = temp
196                        break
197                except:
198                    raise ValueError, "Could not parse file %s" % self.file
199            elif find_description:
200                text = 'text'
201                key2 = "</%s>" % text.lower()
202                if re.search(key2, line) != None:
203                    tok = line.split(key2, 1)
204                    temp = tok[0].split("//", 1)
205                    self.description += tok[1].lstrip().rstrip()
206                    break
207                else:
208                    if re.search("//", line) != None:
209                        temp = line.split("//", 1)
210                        self.description += '\n\t\t' + temp[1].lstrip().rstrip()
211
212                    else:
213                        self.description += '\n\t\t' + line.lstrip().rstrip()
214
215
216
217        for line in lines:
218
219            # Catch class name
220            key = "[PYTHONCLASS]"
221            if line.count(key) > 0:
222                try:
223                    index = line.index(key)
224                    toks = line[index:].split("=")
225                    self.pythonClass = toks[1].lstrip().rstrip()
226
227                except:
228                    raise ValueError, "Could not parse file %s" % self.file
229
230            key = "[CATEGORY]"
231            if line.count(key) > 0:
232                try:
233                    index = line.index(key)
234                    toks = line[index:].split("=")
235                    self.category = toks[1].lstrip().rstrip()
236
237                except:
238                    raise ValueError, "Could not parse file %s" % self.file
239
240            # is_multifunc
241            key = "[MULTIPLICITY_INFO]"
242            if line.count(key) > 0:
243                self.is_multifunc = True
244                try:
245                    index = line.index(key)
246                    toks = line[index:].split("=")
247                    self.multiplicity_info = toks[1].lstrip().rstrip()
248                except:
249                    raise ValueError, "Could not parse file %s" % self.file
250
251            # Catch struct name
252            # C++ class definition
253            if line.count("class") > 0:
254                # We are entering a class definition
255                self.inParDefs = True
256                self.foundCPP = True
257
258            # Old-Style C struct definition
259            if line.count("typedef struct") > 0:
260                # We are entering a struct block
261                self.inParDefs = True
262                self.inStruct = True
263
264            if self.inParDefs and line.count("}") > 0:
265                # We are exiting a struct block
266                self.inParDefs = False
267
268                if self.inStruct:
269                    self.inStruct = False
270                    # Catch the name of the struct
271                    index = line.index("}")
272                    toks = line[index + 1:].split(";")
273                    # Catch pointer definition
274                    toks2 = toks[0].split(',')
275                    self.structName = toks2[0].lstrip().rstrip()
276
277            # Catch struct content
278            key = "[DEFAULT]"
279            if self.inParDefs and line.count(key) > 0:
280                # Found a new parameter
281                try:
282                    index = line.index(key)
283                    toks = line[index:].split("=")
284                    toks2 = toks[2].split()
285                    val = float(toks2[0])
286                    self.params[toks[1]] = val
287                    #self.pythonClass = toks[1].lstrip().rstrip()
288                    units = ""
289                    if len(toks2) >= 2:
290                        units = toks2[1]
291                    self.default_list += "    * %-15s = %s %s\n" % \
292                        (toks[1], val, units)
293
294                    # Check for min and max
295                    value_min = "None"
296                    value_max = "None"
297                    if len(toks2) == 4:
298                        value_min = toks2[2]
299                        value_max = toks2[3]
300
301                    self.details += "        self.details['%s'] = ['%s', %s, %s]\n" % \
302                        (toks[1].lstrip().rstrip(), units.lstrip().rstrip(), value_min, value_max)
303                except:
304                    raise ValueError, "Could not parse input file %s \n  %s" % \
305                        (self.file, sys.exc_value)
306
307
308            # Catch need for numerical calculations
309            key = "CalcParameters calcPars"
310            if line.count(key) > 0:
311                self.modelCalcFlag = True
312
313            # Catch list of dispersed parameters
314            key = "[DISP_PARAMS]"
315            if line.count(key) > 0:
316                try:
317                    index = line.index(key)
318                    toks = line[index:].split("=")
319                    list_str = toks[1].lstrip().rstrip()
320                    self.disp_params = list_str.split(',')
321                except:
322                    raise ValueError, "Could not parse file %s" % self.file
323
324    def write_c_wrapper(self):
325        """ Writes the C file to create the python extension class
326            The file is written in C[PYTHONCLASS].c
327        """
328        file_path = os.path.join(self.c_wrapper_dir,
329                                 "C" + self.pythonClass + '.cpp')
330        file = open(file_path, 'w')
331
332        template = open(os.path.join(os.path.dirname(__file__),
333                                     "classTemplate.txt"), 'r')
334
335        tmp_buf = template.read()
336        #tmp_lines = string.split(tmp_buf,'\n')
337        tmp_lines = tmp_buf.split('\n')
338
339        for tmp_line in tmp_lines:
340
341            # Catch class name
342            newline = self.replaceToken(tmp_line,
343                                        "[PYTHONCLASS]", 'C' + self.pythonClass)
344            #Catch model description
345            #newline = self.replaceToken(tmp_line,
346            #                            "[DESCRIPTION]", self.description)
347            # Catch C model name
348            newline = self.replaceToken(newline,
349                                        "[CMODEL]", self.pythonClass)
350
351            # Catch class name
352            newline = self.replaceToken(newline,
353                                        "[MODELSTRUCT]", self.structName)
354
355            # Sort model initialization based on multifunc
356            if self.is_multifunc:
357                line = "int level = 1;\nPyArg_ParseTuple(args,\"i\",&level);\n"
358                line += "self->model = new " + self.pythonClass + "(level);"
359            else:
360                line = "self->model = new " + self.pythonClass + "();"
361
362            newline = self.replaceToken(newline, "[INITIALIZE_MODEL]", line)
363
364            # Dictionary initialization
365            param_str = "// Initialize parameter dictionary\n"
366            for par in self.params:
367                param_str += "        PyDict_SetItemString(self->params,\"%s\",Py_BuildValue(\"d\",%10.12f));\n" % \
368                    (par, self.params[par])
369
370            if len(self.disp_params) > 0:
371                param_str += "        // Initialize dispersion / averaging parameter dict\n"
372                param_str += "        DispersionVisitor* visitor = new DispersionVisitor();\n"
373                param_str += "        PyObject * disp_dict;\n"
374                for par in self.disp_params:
375                    par = par.strip()
376                    if par == '':
377                        continue
378                    param_str += "        disp_dict = PyDict_New();\n"
379                    param_str += "        self->model->%s.dispersion->accept_as_source(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
380                    param_str += "        PyDict_SetItemString(self->dispersion, \"%s\", disp_dict);\n" % par
381
382            # Initialize dispersion object dictionnary
383            param_str += "\n"
384
385
386            newline = self.replaceToken(newline,
387                                        "[INITDICTIONARY]", param_str)
388
389            # Read dictionary
390            param_str = "    // Reader parameter dictionary\n"
391            for par in self.params:
392                param_str += "    self->model->%s = PyFloat_AsDouble( PyDict_GetItemString(self->params, \"%s\") );\n" % \
393                    (par, par)
394
395            if len(self.disp_params) > 0:
396                param_str += "    // Read in dispersion parameters\n"
397                param_str += "    PyObject* disp_dict;\n"
398                param_str += "    DispersionVisitor* visitor = new DispersionVisitor();\n"
399                for par in self.disp_params:
400                    par = par.strip()
401                    if par == '':
402                        continue
403                    param_str += "    disp_dict = PyDict_GetItemString(self->dispersion, \"%s\");\n" % par
404                    param_str += "    self->model->%s.dispersion->accept_as_destination(visitor, self->model->%s.dispersion, disp_dict);\n" % (par, par)
405
406            newline = self.replaceToken(newline, "[READDICTIONARY]", param_str)
407
408            # Name of .c file
409            #toks = string.split(self.file,'.')
410            basename = os.path.basename(self.file)
411            toks = basename.split('.')
412            newline = self.replaceToken(newline, "[C_FILENAME]", toks[0])
413
414            # Include file
415            basename = os.path.basename(self.file)
416            newline = self.replaceToken(newline,
417                                        "[INCLUDE_FILE]", self.file)
418            if self.foundCPP:
419                newline = self.replaceToken(newline, "[C_INCLUDE_FILE]", "")
420                newline = self.replaceToken(newline,
421                                            "[CPP_INCLUDE_FILE]",
422                                            "#include \"%s\"" % basename)
423            else:
424                newline = self.replaceToken(newline,
425                                            "[C_INCLUDE_FILE]",
426                                            "#include \"%s\"" % basename)
427                newline = self.replaceToken(newline,
428                                            "[CPP_INCLUDE_FILE]",
429                                            "#include \"models.hh\"")
430
431            # Numerical calcs dealloc
432            dealloc_str = "\n"
433            if self.modelCalcFlag:
434                dealloc_str = "    modelcalculations_dealloc(&(self->model_pars.calcPars));\n"
435            newline = self.replaceToken(newline,
436                                        "[NUMERICAL_DEALLOC]", dealloc_str)
437
438            # Numerical calcs init
439            init_str = "\n"
440            if self.modelCalcFlag:
441                init_str = "        modelcalculations_init(&(self->model_pars.calcPars));\n"
442            newline = self.replaceToken(newline,
443                                        "[NUMERICAL_INIT]", init_str)
444
445            # Numerical calcs reset
446            reset_str = "\n"
447            if self.modelCalcFlag:
448                reset_str = "modelcalculations_reset(&(self->model_pars.calcPars));\n"
449            newline = self.replaceToken(newline,
450                                        "[NUMERICAL_RESET]", reset_str)
451
452            # Setting dispsertion weights
453            set_weights = "    // Ugliness necessary to go from python to C\n"
454            set_weights = "    // TODO: refactor this\n"
455            for par in self.disp_params:
456                par = par.strip()
457                if par == '':
458                        continue
459                set_weights += "    if (!strcmp(par_name, \"%s\")) {\n" % par
460                set_weights += "        self->model->%s.dispersion = dispersion;\n" % par
461                set_weights += "    } else"
462            newline = self.replaceToken(newline,
463                                        "[SET_DISPERSION]", set_weights)
464
465            # Write new line to the wrapper .c file
466            file.write(newline + '\n')
467
468
469        file.close()
470
471    def write_python_wrapper(self):
472        """ Writes the python file to create the python extension class
473            The file is written in ../[PYTHONCLASS].py
474        """
475        file_path = os.path.join(self.output_dir, self.pythonClass + '.py')
476        file = open(file_path, 'w')
477        template = open(os.path.join(os.path.dirname(__file__),
478                                     "modelTemplate.txt"), 'r')
479
480        tmp_buf = template.read()
481        tmp_lines = tmp_buf.split('\n')
482
483        for tmp_line in tmp_lines:
484
485            # Catch class name
486            newline = self.replaceToken(tmp_line,
487                                        "[CPYTHONCLASS]",
488                                        'C' + self.pythonClass)
489
490            # Catch class name
491            newline = self.replaceToken(newline,
492                                        "[PYTHONCLASS]", self.pythonClass)
493
494            # Include file
495            newline = self.replaceToken(newline,
496                                        "[INCLUDE_FILE]", self.file)
497
498            # Include file
499            newline = self.replaceToken(newline,
500                                        "[DEFAULT_LIST]", self.default_list)
501            # model description
502            newline = self.replaceToken(newline,
503                                        "[DESCRIPTION]", self.description)
504            # Parameter details
505            newline = self.replaceToken(newline,
506                                        "[PAR_DETAILS]", self.details)
507
508            # Call base constructor
509            if self.is_multifunc:
510                newline = self.replaceToken(newline, "[CALL_CPYTHON_INIT]",
511                                            'C' + self.pythonClass + \
512                    ".__init__(self,multfactor)\n\tself.is_multifunc = True")
513                newline = self.replaceToken(newline, "[MULTIPLICITY_INFO]",
514                                            self.multiplicity_info)
515            else:
516                newline = self.replaceToken(newline, "[CALL_CPYTHON_INIT]",
517                                            'C' + self.pythonClass + \
518                    ".__init__(self)\n        self.is_multifunc = False")
519                newline = self.replaceToken(newline,
520                                            "[MULTIPLICITY_INFO]", "None")
521
522            # fixed list  details
523            fixed_str = str(self.fixed)
524            fixed_str = fixed_str.replace(', ', ',\n                      ')
525            newline = self.replaceToken(newline, "[FIXED]", fixed_str)
526
527            # non-fittable list details
528            pars_str = str(self.non_fittable)
529            pars_str = pars_str.replace(', ',
530                                        ',\n                             ')
531            newline = self.replaceToken(newline,
532                                        "[NON_FITTABLE_PARAMS]", pars_str)
533
534            ## parameters with orientation
535            oriented_str = str(self.orientation_params)
536            formatted_endl = ',\n                                   '
537            oriented_str = oriented_str.replace(', ', formatted_endl)
538            newline = self.replaceToken(newline,
539                                        "[ORIENTATION_PARAMS]", oriented_str)
540            ## parameters with magnetism
541            newline = self.replaceToken(newline,
542                                        "[MAGNETIC_PARAMS]",
543                                        str(self.magnetic_params))
544
545            if self.category:
546                newline = self.replaceToken(newline, "[CATEGORY]",
547                                            '"' + self.category + '"')
548            else:
549                newline = self.replaceToken(newline, "[CATEGORY]",
550                                            "None")
551
552
553
554            # Write new line to the wrapper .c file
555            file.write(newline + '\n')
556
557        file.close()
558
559
560    def replaceToken(self, line, key, value): #pylint: disable-msg=R0201
561        """ Replace a token in the template file
562            @param line: line of text to inspect
563            @param key: token to look for
564            @param value: string value to replace the token with
565            @return: new string value
566        """
567        lenkey = len(key)
568        newline = line
569
570        while newline.count(key) > 0:
571            index = newline.index(key)
572            newline = newline[:index] + value + newline[index + lenkey:]
573
574        return newline
575
576    def getModelName(self):
577        return self.pythonClass
578
579
580
581# main
582if __name__ == '__main__':
583    if len(sys.argv) > 1:
584        print "Will look for file %s" % sys.argv[1]
585        app = WrapperGenerator(sys.argv[1])
586    else:
587        app = WrapperGenerator("test.h")
588    app.read()
589    app.write_c_wrapper()
590    app.write_python_wrapper()
591    print app
592
593# End of file       
Note: See TracBrowser for help on using the repository browser.