source: sasview/sansmodels/src/sans/models/test/testcase_generator.py @ f629e346

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since f629e346 was ae3ce4e, checked in by Mathieu Doucet <doucetm@…>, 17 years ago

Moving sansmodels to trunk

  • Property mode set to 100644
File size: 22.2 KB
Line 
1"""
2    Random test-case generator for BaseComponents
3   
4    @author: Mathieu Doucet / UTK
5    @license: This software is provided as part of the DANSE project.
6"""
7import time
8
9def randomModel():
10    """ Return a random model name """
11    from sans.models.ModelFactory import ModelFactory
12    from random import random
13    from math import floor
14   
15    model_list = ModelFactory().getAllModels()
16    ready = False
17    while not ready:
18        rnd_id = int(floor(random()*len(model_list)))
19        if model_list[rnd_id] not in ["TestSphere2", "DataModel"]:
20            ready = True
21    return model_list[rnd_id]
22
23class ReportCard: # pylint: disable-msg=R0902
24    """ Class to hold test-case results """
25   
26    def __init__(self):
27        """ Initialization """
28        ## Number of test cases
29        self.n_cases = 0
30        ## Number of passed test cases
31        self.n_cases_pass = 0
32        ## Number of Evaluation calls
33        self.n_eval = 0
34        ## Number of passed Evaluation calls
35        self.n_eval_pass = 0
36        ## Number of SetParam calls
37        self.n_set = 0
38        ## Number of passed Set calls
39        self.n_set_pass = 0
40        ## Number of GetParam calls
41        self.n_get = 0
42        ## Number of passed Get calls
43        self.n_get_pass = 0
44        ## Number of load calls
45        self.n_load = 0
46        ## Number of passed Load calls
47        self.n_load_pass = 0
48        ## Number of save calls
49        self.n_save = 0
50        ## Number of passed Save calls
51        self.n_save_pass = 0
52        ## Number of Add calls
53        self.n_add = 0
54        ## Number of passed Add calls
55        self.n_add_pass = 0
56        ## Number of Sub calls
57        self.n_sub = 0
58        ## Number of passed Sub calls
59        self.n_sub_pass = 0
60        ## Number of Div calls
61        self.n_div = 0
62        ## Number of passed Div calls
63        self.n_div_pass = 0
64        ## Number of mul calls
65        self.n_mul = 0
66        ## Number of passed Mul calls
67        self.n_mul_pass = 0
68        ## Test log
69        self.log = ""
70        ## Trace
71        self.trace = ""
72        ## Model tested
73        self.model = ""
74        ## List of all models tested
75        self.modelList = []
76       
77       
78    def __add__(self, other):
79        """ Add two report cards
80            @param other: other report to add
81            @return: updated self
82        """
83       
84        self.n_cases += other.n_cases
85        self.n_cases_pass += other.n_cases_pass
86        self.n_eval += other.n_eval
87        self.n_eval_pass += other.n_eval_pass
88        self.n_set += other.n_set
89        self.n_set_pass += other.n_set_pass
90        self.n_get += other.n_get
91        self.n_get_pass += other.n_get_pass
92        self.n_load += other.n_load
93        self.n_load_pass += other.n_load_pass
94        self.n_save += other.n_save
95        self.n_save_pass += other.n_save_pass
96        self.n_add += other.n_add
97        self.n_add_pass += other.n_add_pass
98        self.n_sub += other.n_sub
99        self.n_sub_pass += other.n_sub_pass
100        self.n_mul += other.n_mul
101        self.n_mul_pass += other.n_mul_pass
102        self.n_div += other.n_div
103        self.n_div_pass += other.n_div_pass
104        if len(other.log)>0:
105            self.log += other.log
106        if len(other.trace)>0:
107            self.trace += other.trace
108           
109        if not other.model in self.modelList:
110            self.modelList.append(other.model)
111           
112        return self
113       
114    def isPassed(self):
115        """ Returns true if no error was found """
116        if self.n_cases_pass < self.n_cases \
117            or self.n_save_pass < self.n_save \
118            or self.n_load_pass < self.n_load \
119            or self.n_set_pass < self.n_set \
120            or self.n_get_pass < self.n_get \
121            or self.n_eval_pass < self.n_eval \
122            or self.n_add_pass < self.n_add \
123            or self.n_sub_pass < self.n_sub \
124            or self.n_mul_pass < self.n_mul \
125            or self.n_div_pass < self.n_div:
126            return False
127        return True
128       
129    def __str__(self):
130        """ String representation of the report card """
131        from sans.models.ModelFactory import ModelFactory
132       
133        rep = "Detailed output:\n"
134        rep += self.log
135        rep += "\n"
136        rep += "Total number of cases: %g\n" % self.n_cases
137        rep += "   Passed:             %g\n" % self.n_cases_pass
138        rep += "\n"
139        self.modelList.sort()
140        rep += "Models tested: %s\n" % self.modelList
141        all_models = ModelFactory().getAllModels()
142        all_models.sort()
143        rep += "\n"
144        rep += "Available models: %s\n" % all_models
145        rep += "\n"
146        rep += "Breakdown:\n"
147        rep += "   Eval:          %g / %g\n" % (self.n_eval_pass, self.n_eval)
148        rep += "   Set:           %g / %g\n" % (self.n_set_pass, self.n_set)
149        rep += "   Get:           %g / %g\n" % (self.n_get_pass, self.n_get)
150        rep += "   Load:          %g / %g\n" % (self.n_load_pass, self.n_load)
151        rep += "   Save:          %g / %g\n" % (self.n_save_pass, self.n_save)
152        rep += "   Add:           %g / %g\n" % (self.n_add_pass, self.n_add)
153        rep += "   Sub:           %g / %g\n" % (self.n_sub_pass, self.n_sub)
154        rep += "   Mul:           %g / %g\n" % (self.n_mul_pass, self.n_mul)
155        rep += "   Div:           %g / %g\n" % (self.n_div_pass, self.n_div)
156        return rep
157   
158       
159
160class TestCaseGenerator:
161    """ Generator for suite of test-cases
162    """
163   
164    def __init__(self):
165        """ Initialization
166        """
167       
168        self.n_tests = 0
169        self.n_passed = 0
170        self.time = 0
171        self.reportCard = ReportCard()
172   
173    def generateFiles(self, number, file_prefix):
174        """ Generate test-case files
175            @param number: number of files to generate
176            @param file_prefix: prefix for the file names
177        """
178       
179        for i in range(number):
180            filename = "%s_%d.xml" % (file_prefix, i)
181            self.generateFileTest(filename)
182            self.n_tests += 1
183           
184    def generateAndRun(self, number):
185        """ Generate test-cases and run them
186            @param number: number of test-cases to generate
187        """
188        start_time = time.time()
189        for i in range(number):
190            textcase = self.generateTest()
191            t = TestCase(text = textcase)
192            passed = t.run()
193            self.reportCard += t.reportCard
194            if not passed:
195                t = time.localtime()
196                xmloutput = open("error_%i-%i-%i-%i-%i-%i_%i.xml" % \
197                 (t[0],t[1],t[2],t[3],t[4],t[5],self.reportCard.n_cases),'w')
198                xmloutput.write(textcase)
199                xmloutput.close()
200
201               
202        self.time += time.time()-start_time
203        print self.reportCard       
204       
205    def generateFileTest(self, filename = "tmp.xml"):
206        """
207            Write a random test-case in an XML file
208            @param filename: name of file to write to
209        """
210        text = self.generateTest()
211        # Write test case in file
212        xmlfile = open(filename,'w')
213        xmlfile.write(text)
214        xmlfile.close()
215   
216   
217    def generateTest(self):
218        """
219            Generate an XML representation of a random test-case
220        """
221        import random
222       
223        #t = TestCase()
224        text = "<?xml version=\"1.0\"?>\n"
225        text  += "<TestCase>\n"
226        stimulus = "eval"
227   
228        loop = True
229        while loop:
230            rnd = random.random()
231           
232            # run 50%
233            if rnd <= 0.5:
234                stimulus = "eval"
235            elif rnd>0.5 and rnd<=0.7:
236                stimulus = "set"
237            elif rnd>0.7 and rnd<=0.72:
238                stimulus = "save"
239            elif rnd>0.72 and rnd<=0.74:
240                stimulus = "load"
241            elif rnd>0.74 and rnd<=0.8:
242                stimulus = "get"
243            elif rnd>0.8 and rnd<=0.81:
244                stimulus = "add"
245            elif rnd>0.81 and rnd<=0.82:
246                stimulus = "div"
247            elif rnd>0.82 and rnd<=0.83:
248                stimulus = "mul"
249            elif rnd>0.83 and rnd<=0.84:
250                stimulus = "sub"
251            else:
252                # just run and quit
253                stimulus = "eval"
254                loop = False
255               
256            text += "  <Stimulus id=\"%s\"/>\n" % stimulus
257        text += "</TestCase>"
258       
259        return text
260   
261       
262class TestCase:
263    """ Test-case class """
264   
265    def __init__(self, filename = None, text = None):
266        """ Initialization
267            @param filename: name of file containing the test case
268        """
269        #from sans.models.ModelFactory import ModelFactory
270        self.filename = filename
271        self.text = text
272        #self.model = ModelFactory().getModel(randomModel())
273        self.model = getRandomModelObject()
274        #self.model = ModelFactory().getModel("SphereModel")
275        self.passed = True
276        self.reportCard = ReportCard()
277       
278   
279    def run(self):
280        """ Read the test case and execute it """
281        from xml.dom.minidom import parse
282        from xml.dom.minidom import parseString
283       
284        # Initialize report
285        self.reportCard = ReportCard()
286        self.reportCard.model = self.model.name
287        self.reportCard.n_cases = 1
288       
289        if not self.text == None:
290            dom = parseString(self.text)
291        elif not self.filename == None:
292            dom = parse(self.filename)
293        else:
294            print "No input to parse"
295            return False
296       
297        if dom.hasChildNodes():
298            for n in dom.childNodes:
299                if n.nodeName == "TestCase":
300                    self.processStimuli(n)
301           
302        # Update report card       
303        if self.passed:
304            self.reportCard.n_cases_pass = 1
305           
306        return self.passed
307                         
308    def processStimuli(self, node):
309        """ Process the stimuli list in the TestCase node
310            of an XML test-case file
311            @param node: test-case node
312        """
313        import testcase_generator as generator
314        report = ReportCard()
315        report.trace = "%s\n" % self.model.name
316       
317        self.passed = True
318        if node.hasChildNodes():
319            for n in node.childNodes:
320                if n.nodeName == "Stimulus":
321                   
322                    s_type = None
323                    if n.hasAttributes():
324                        # Get stilumus ID
325                        for i in range(len(n.attributes)):
326                            attr_name = n.attributes.item(i).name
327                            if attr_name == "id":
328                                s_type = n.attributes.item(i).nodeValue
329                    if hasattr(generator,"%sStimulus" % s_type):
330                        stimulus = getattr(generator,"%sStimulus" % s_type)
331                        #print s_type, self.model.name
332                        m, res = stimulus(self.model)
333                        #print "     ", m.name
334                        self.model = m
335                        if not res.isPassed():
336                            self.passed = False
337                       
338                        report += res
339                       
340                    else:
341                        print "Stimulus %s not found" % s_type
342                       
343        self.reportCard += report
344       
345        if not self.passed:
346            print report.trace
347        return self.passed
348       
349   
350def evalStimulus(model):
351    """ Apply the stimulus to the supplied model object
352        @param model: model to apply the stimulus to
353        @return: True if passed, False otherwise
354    """
355    minval = 0
356    maxval = 20
357    # Call run with random input
358    import random, math
359    input_value = random.random()*(maxval-minval)+minval
360   
361    # Catch division by zero, which can happen
362    # when creating random models
363    try:
364        # Choose whether we want 1D or 2D
365        if random.random()>0.5:
366            output = model.run(input_value)
367        else:
368            output = model.run([input_value, 2*math.pi*random.random()])           
369    except ZeroDivisionError:
370        output = -1
371       
372    #print "Eval: %g" % output
373   
374    # Oracle bit - just check that we have a number...
375    passed = False
376    if math.fabs(output) >= 0: 
377        passed = True
378   
379    report = ReportCard()
380    report.n_eval = 1
381    if passed:
382        report.n_eval_pass = 1
383    else:
384        report.log = "Eval: bad output value %g\n" % output
385       
386    report.trace = "Eval(%g) = %g %i\n" % (input_value, output, passed)   
387    return model, report
388
389def setStimulus(model):
390    """ Apply the stimulus to the supplied model object
391        @param model: model to apply the stimulus to
392        @return: True if passed, False otherwize
393    """
394    # Call run with random input
395    import random, math
396    minval = 1
397    maxval = 50
398   
399    # Choose a parameter to change
400    keys = model.getParamList()
401    if len(keys)==0:
402        #print model.name+" has no params"
403        return model, ReportCard()
404   
405    p_len = len(keys)
406    i_param = int(math.floor(random.random()*p_len))
407    p_name  = keys[i_param]
408   
409    # Chose a value
410    input_val = random.random()*(maxval-minval)+minval
411    model.setParam(p_name, input_val)
412   
413    # Read back
414    check_val = model.getParam(p_name)
415    #print "Set: %s = %g" % (p_name, check_val)
416   
417    # Oracle bit
418    passed = False
419    if check_val == input_val: 
420        passed = True
421   
422    report = ReportCard()
423    report.n_set = 1
424    if passed:
425        report.n_set_pass = 1
426    else:
427        report.log = "Set: parameter %s not set properly %g <> %g\n" % \
428            (p_name, input_val, check_val)
429       
430    report.trace = "Set %s = %g %i\n" % (p_name, input_val, passed)   
431    return model, report
432
433def getStimulus(model):
434    """ Apply the stimulus to the supplied model object
435        @param model: model to apply the stimulus to
436        @return: True if passed, False otherwise
437    """
438    import random, math
439    # Choose a parameter to change
440    keys = model.getParamList()
441    if len(keys)==0:
442        #print model.name+" has no params"
443        return model, ReportCard()   
444   
445    p_len = len(keys)
446    i_param = int(math.floor(random.random()*p_len))
447    p_name  = keys[i_param]
448   
449    # Read back
450    check_val = model.getParam(p_name)
451    #print "Get: %s = %g" % (p_name, check_val)
452   
453    # Oracle bit
454    passed = False
455    if math.fabs(check_val) >= 0: 
456        passed = True
457       
458    report = ReportCard()
459    report.n_get = 1
460    if passed:
461        report.n_get_pass = 1
462    else:
463        report.log = "Get: bad value for parameter %s: %g\n" % \
464            (p_name, check_val)
465       
466    report.trace = "Get %s = %g %i\n" % (p_name, check_val, passed)   
467    return model, report
468
469def loadStimulus(model):
470    """ Apply the stimulus to the supplied model object
471        @param model: model to apply the stimulus to
472        @return: True if passed, False otherwize
473    """
474    from sans.models.ModelIO import ModelIO
475    from sans.models.ModelFactory import ModelFactory
476    factory = ModelFactory()
477    io = ModelIO(factory)
478   
479    # testModel.xml should be in the directory
480    loaded = io.load("load_test_model.xml")
481    value2 = loaded.run(1)
482   
483    # Oracle bit
484    passed = False
485    if not value2 == 0: 
486        passed = True
487       
488    report = ReportCard()
489    report.n_load = 1
490    if passed:
491        report.n_load_pass = 1
492    else:
493        report.log = "Load: bad loaded model\n"
494       
495    report.trace = "Load = %s %i\n" % (loaded.name, passed)   
496    return model, report
497
498def saveStimulus(model):
499    """ Apply the stimulus to the supplied model object
500        @param model: model to apply the stimulus to
501        @return: True if passed, False otherwize
502    """
503    from sans.models.ModelIO import ModelIO
504    from sans.models.ModelFactory import ModelFactory
505    factory = ModelFactory()
506    io = ModelIO(factory)
507    io.save(model,"testModel.xml")
508    #value = model.run(1)
509
510    # Check it
511    loaded = io.load("testModel.xml")
512    try:
513        value2 = loaded.run(1)
514    except ZeroDivisionError:
515        value2 = -1
516   
517    # Oracle bit
518    passed = False
519   
520    # The two outputs do not need to be the same
521    # since we do not save the parameters!
522    # If it's the subtraction of two identical models,
523    # we will also get zero!
524    #if value2 == value:
525    #    passed = True
526   
527    passed = True
528
529       
530    report = ReportCard()
531    report.n_save = 1
532    if passed:
533        report.n_save_pass = 1
534    else:
535        report.log = "Save: bad output from saved model %g\n" % value2
536       
537    report.trace = "Save %s %i\n" % (model.name, passed)   
538    return model, report
539
540def getRandomModelObject():
541    """
542        Return a random model object
543    """
544    from sans.models.ModelFactory import ModelFactory
545    while True:
546        try:
547            modelname = randomModel()
548            add_model = ModelFactory().getModel(modelname)
549            return add_model
550        except:
551            # Don't complain when a model can't be loaded.
552            # A list of tested models will appear in the
553            # file report, which can be compared with the
554            # list of available models, also printed.
555            pass
556            #print "Could not load ", modelname
557   
558
559def addStimulus(model):
560    """ Apply the stimulus to the supplied model object
561        @param model: model to apply the stimulus to
562        @return: True if passed, False otherwize
563    """
564    #from sans.models.ModelFactory import ModelFactory
565    #factory = ModelFactory()
566    #add_model = factory.getModel("SphereModel")
567    add_model = getRandomModelObject()
568           
569    tmp = model + add_model
570    model = tmp
571   
572    # Oracle bit
573    passed = False
574   
575    try:
576        value2 = model.run(1)
577        value2 = float(value2)
578    except:
579        passed = False
580   
581    # If we made it this far, we have a float
582    passed = True 
583       
584    report = ReportCard()
585    report.n_add = 1
586    if passed:
587        report.n_add_pass = 1
588    else:
589        report.log = "Add: bad output from composite model\n"
590       
591    report.trace = "Div %s %i\n" % (model.name, passed)   
592    return model, report
593
594def subStimulus(model):
595    """ Apply the stimulus to the supplied model object
596        @param model: model to apply the stimulus to
597        @return: True if passed, False otherwize
598    """
599    from sans.models.ModelFactory import ModelFactory
600    from random import random
601   
602    #factory = ModelFactory()
603    #sub_model = factory.getModel("CylinderModel")
604    #sub_model = factory.getModel(randomModel())
605    sub_model = getRandomModelObject()
606
607   
608    sub_model.params["scale"] = 1.1*random()
609    tmp = model - sub_model
610   
611    # Oracle bit
612    passed = False
613   
614    try:
615        value2 = tmp.run(1)
616        value2 = float(value2)
617    except:
618        passed = False
619   
620    # If we made it this far, we have a float
621    passed = True 
622       
623    report = ReportCard()
624    report.n_sub = 1
625    if passed:
626        report.n_sub_pass = 1
627    else:
628        report.log = "Sub: bad output from composite model\n"
629
630    report.trace = "Sub %s (%g - %g = %g) %i\n" % \
631        (model.name, model.run(1), \
632         sub_model.run(1), value2, passed)               
633    return tmp, report
634
635def mulStimulus(model):
636    """ Apply the stimulus to the supplied model object
637        @param model: model to apply the stimulus to
638        @return: True if passed, False otherwize
639    """
640    #from sans.models.ModelFactory import ModelFactory
641    #factory = ModelFactory()
642    #mul_model = factory.getModel("SphereModel")
643    #mul_model = factory.getModel(randomModel())
644    mul_model = getRandomModelObject()
645    tmp = model * mul_model
646   
647    # Oracle bit
648    passed = False
649
650    try:
651        value2 = tmp.run(1)
652        value2 = float(value2)
653    except:
654        passed = False
655   
656    # If we made it this far, we have a float
657    passed = True 
658       
659    report = ReportCard()
660    report.n_mul = 1
661    if passed:
662        report.n_mul_pass = 1
663    else:
664        report.log = "Mul: bad output from composite model\n"
665       
666    report.trace = "Mul %s (%g * %g = %g) %i\n" % \
667      (model.name, model.run(1), \
668       mul_model.run(1), value2, passed)               
669    return tmp, report
670
671def divStimulus(model):
672    """ Apply the stimulus to the supplied model object
673        @param model: model to apply the stimulus to
674        @return: True if passed, False otherwise
675    """
676    #from sans.models.ModelFactory import ModelFactory
677    #factory = ModelFactory()
678    #div_model = factory.getModel("SphereModel")
679    #div_model = factory.getModel(randomModel())
680    div_model = getRandomModelObject()
681   
682    tmp = model / div_model
683   
684    # Oracle bit
685    passed = False
686   
687    try:
688        from random import random
689        input_val = 1.5 * random()
690        if div_model.run(input_val)==0:
691            print "Skipped (DIV) because denominator evaluated to zero"
692        else:
693            value2 = tmp.run(input_val)
694            value2 = float(value2)
695    except:
696        passed = False
697   
698    # If we made it this far, we have a float
699    passed = True 
700       
701    report = ReportCard()
702    report.n_div = 1
703    if passed:
704        report.n_div_pass = 1
705    else:
706        report.log = "Div: bad output from composite model\n"
707       
708    report.trace = "Div %s (%g / %g = %g) %i\n" % \
709        (model.name, model.run(1), \
710         div_model.run(1), value2, passed)               
711    return tmp, report
712
713if __name__ == '__main__':
714
715    #print randomModel()
716    g = TestCaseGenerator()
717    g.generateAndRun(2000)
718   
719    #t = TestCase(filename = "error_1.17721e+009.xml")
720    #print t.run()
721   
Note: See TracBrowser for help on using the repository browser.