source: sasview/sansmodels/src/sans/models/test/testcase_generator.py @ e71440c

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since e71440c was 36948c92, checked in by Mathieu Doucet <doucetm@…>, 17 years ago

Modified to testing and modified according to testing results

  • Property mode set to 100644
File size: 24.0 KB
RevLine 
[ae3ce4e]1"""
2    Random test-case generator for BaseComponents
3   
4    @author: Mathieu Doucet / UTK
5    @license: This software is provided as part of the DANSE project.
6"""
7import time
8
9def randomModel():
10    """ Return a random model name """
11    from sans.models.ModelFactory import ModelFactory
12    from random import random
13    from math import floor
14   
15    model_list = ModelFactory().getAllModels()
16    ready = False
17    while not ready:
18        rnd_id = int(floor(random()*len(model_list)))
19        if model_list[rnd_id] not in ["TestSphere2", "DataModel"]:
20            ready = True
21    return model_list[rnd_id]
22
23class ReportCard: # pylint: disable-msg=R0902
24    """ Class to hold test-case results """
25   
26    def __init__(self):
27        """ Initialization """
28        ## Number of test cases
29        self.n_cases = 0
30        ## Number of passed test cases
31        self.n_cases_pass = 0
32        ## Number of Evaluation calls
33        self.n_eval = 0
34        ## Number of passed Evaluation calls
35        self.n_eval_pass = 0
36        ## Number of SetParam calls
37        self.n_set = 0
38        ## Number of passed Set calls
39        self.n_set_pass = 0
40        ## Number of GetParam calls
41        self.n_get = 0
42        ## Number of passed Get calls
43        self.n_get_pass = 0
44        ## Number of load calls
45        self.n_load = 0
46        ## Number of passed Load calls
47        self.n_load_pass = 0
48        ## Number of save calls
49        self.n_save = 0
50        ## Number of passed Save calls
51        self.n_save_pass = 0
52        ## Number of Add calls
53        self.n_add = 0
54        ## Number of passed Add calls
55        self.n_add_pass = 0
56        ## Number of Sub calls
57        self.n_sub = 0
58        ## Number of passed Sub calls
59        self.n_sub_pass = 0
60        ## Number of Div calls
61        self.n_div = 0
62        ## Number of passed Div calls
63        self.n_div_pass = 0
64        ## Number of mul calls
65        self.n_mul = 0
66        ## Number of passed Mul calls
67        self.n_mul_pass = 0
68        ## Test log
69        self.log = ""
70        ## Trace
71        self.trace = ""
72        ## Model tested
73        self.model = ""
74        ## List of all models tested
75        self.modelList = []
76       
77       
78    def __add__(self, other):
79        """ Add two report cards
80            @param other: other report to add
81            @return: updated self
82        """
83       
84        self.n_cases += other.n_cases
85        self.n_cases_pass += other.n_cases_pass
86        self.n_eval += other.n_eval
87        self.n_eval_pass += other.n_eval_pass
88        self.n_set += other.n_set
89        self.n_set_pass += other.n_set_pass
90        self.n_get += other.n_get
91        self.n_get_pass += other.n_get_pass
92        self.n_load += other.n_load
93        self.n_load_pass += other.n_load_pass
94        self.n_save += other.n_save
95        self.n_save_pass += other.n_save_pass
96        self.n_add += other.n_add
97        self.n_add_pass += other.n_add_pass
98        self.n_sub += other.n_sub
99        self.n_sub_pass += other.n_sub_pass
100        self.n_mul += other.n_mul
101        self.n_mul_pass += other.n_mul_pass
102        self.n_div += other.n_div
103        self.n_div_pass += other.n_div_pass
104        if len(other.log)>0:
105            self.log += other.log
106        if len(other.trace)>0:
107            self.trace += other.trace
108           
109        if not other.model in self.modelList:
110            self.modelList.append(other.model)
111           
112        return self
113       
114    def isPassed(self):
115        """ Returns true if no error was found """
116        if self.n_cases_pass < self.n_cases \
117            or self.n_save_pass < self.n_save \
118            or self.n_load_pass < self.n_load \
119            or self.n_set_pass < self.n_set \
120            or self.n_get_pass < self.n_get \
121            or self.n_eval_pass < self.n_eval \
122            or self.n_add_pass < self.n_add \
123            or self.n_sub_pass < self.n_sub \
124            or self.n_mul_pass < self.n_mul \
125            or self.n_div_pass < self.n_div:
126            return False
127        return True
128       
129    def __str__(self):
130        """ String representation of the report card """
131        from sans.models.ModelFactory import ModelFactory
132       
133        rep = "Detailed output:\n"
134        rep += self.log
135        rep += "\n"
136        rep += "Total number of cases: %g\n" % self.n_cases
137        rep += "   Passed:             %g\n" % self.n_cases_pass
138        rep += "\n"
139        self.modelList.sort()
140        rep += "Models tested: %s\n" % self.modelList
141        all_models = ModelFactory().getAllModels()
142        all_models.sort()
143        rep += "\n"
144        rep += "Available models: %s\n" % all_models
145        rep += "\n"
146        rep += "Breakdown:\n"
147        rep += "   Eval:          %g / %g\n" % (self.n_eval_pass, self.n_eval)
148        rep += "   Set:           %g / %g\n" % (self.n_set_pass, self.n_set)
149        rep += "   Get:           %g / %g\n" % (self.n_get_pass, self.n_get)
150        rep += "   Load:          %g / %g\n" % (self.n_load_pass, self.n_load)
151        rep += "   Save:          %g / %g\n" % (self.n_save_pass, self.n_save)
152        rep += "   Add:           %g / %g\n" % (self.n_add_pass, self.n_add)
153        rep += "   Sub:           %g / %g\n" % (self.n_sub_pass, self.n_sub)
154        rep += "   Mul:           %g / %g\n" % (self.n_mul_pass, self.n_mul)
155        rep += "   Div:           %g / %g\n" % (self.n_div_pass, self.n_div)
156        return rep
157   
158       
159
160class TestCaseGenerator:
161    """ Generator for suite of test-cases
162    """
163   
164    def __init__(self):
165        """ Initialization
166        """
167       
168        self.n_tests = 0
169        self.n_passed = 0
170        self.time = 0
171        self.reportCard = ReportCard()
172   
173    def generateFiles(self, number, file_prefix):
174        """ Generate test-case files
175            @param number: number of files to generate
176            @param file_prefix: prefix for the file names
177        """
178       
179        for i in range(number):
180            filename = "%s_%d.xml" % (file_prefix, i)
181            self.generateFileTest(filename)
182            self.n_tests += 1
183           
184    def generateAndRun(self, number):
185        """ Generate test-cases and run them
186            @param number: number of test-cases to generate
187        """
188        start_time = time.time()
189        for i in range(number):
190            textcase = self.generateTest()
191            t = TestCase(text = textcase)
192            passed = t.run()
193            self.reportCard += t.reportCard
194            if not passed:
195                t = time.localtime()
196                xmloutput = open("error_%i-%i-%i-%i-%i-%i_%i.xml" % \
197                 (t[0],t[1],t[2],t[3],t[4],t[5],self.reportCard.n_cases),'w')
198                xmloutput.write(textcase)
199                xmloutput.close()
200
201               
202        self.time += time.time()-start_time
203        print self.reportCard       
204       
205    def generateFileTest(self, filename = "tmp.xml"):
206        """
207            Write a random test-case in an XML file
208            @param filename: name of file to write to
209        """
210        text = self.generateTest()
211        # Write test case in file
212        xmlfile = open(filename,'w')
213        xmlfile.write(text)
214        xmlfile.close()
215   
216   
217    def generateTest(self):
218        """
219            Generate an XML representation of a random test-case
220        """
221        import random
222       
223        #t = TestCase()
224        text = "<?xml version=\"1.0\"?>\n"
225        text  += "<TestCase>\n"
226        stimulus = "eval"
227   
228        loop = True
229        while loop:
230            rnd = random.random()
231           
232            # run 50%
233            if rnd <= 0.5:
234                stimulus = "eval"
235            elif rnd>0.5 and rnd<=0.7:
236                stimulus = "set"
237            elif rnd>0.7 and rnd<=0.72:
238                stimulus = "save"
239            elif rnd>0.72 and rnd<=0.74:
240                stimulus = "load"
241            elif rnd>0.74 and rnd<=0.8:
242                stimulus = "get"
243            elif rnd>0.8 and rnd<=0.81:
244                stimulus = "add"
245            elif rnd>0.81 and rnd<=0.82:
246                stimulus = "div"
247            elif rnd>0.82 and rnd<=0.83:
248                stimulus = "mul"
249            elif rnd>0.83 and rnd<=0.84:
250                stimulus = "sub"
251            else:
252                # just run and quit
253                stimulus = "eval"
254                loop = False
255               
256            text += "  <Stimulus id=\"%s\"/>\n" % stimulus
257        text += "</TestCase>"
258       
259        return text
260   
261       
262class TestCase:
263    """ Test-case class """
264   
265    def __init__(self, filename = None, text = None):
266        """ Initialization
267            @param filename: name of file containing the test case
268        """
269        #from sans.models.ModelFactory import ModelFactory
270        self.filename = filename
271        self.text = text
272        #self.model = ModelFactory().getModel(randomModel())
273        self.model = getRandomModelObject()
274        #self.model = ModelFactory().getModel("SphereModel")
275        self.passed = True
276        self.reportCard = ReportCard()
277       
278   
279    def run(self):
280        """ Read the test case and execute it """
281        from xml.dom.minidom import parse
282        from xml.dom.minidom import parseString
283       
284        # Initialize report
285        self.reportCard = ReportCard()
286        self.reportCard.model = self.model.name
287        self.reportCard.n_cases = 1
288       
289        if not self.text == None:
290            dom = parseString(self.text)
291        elif not self.filename == None:
292            dom = parse(self.filename)
293        else:
294            print "No input to parse"
295            return False
296       
297        if dom.hasChildNodes():
298            for n in dom.childNodes:
299                if n.nodeName == "TestCase":
300                    self.processStimuli(n)
301           
302        # Update report card       
303        if self.passed:
304            self.reportCard.n_cases_pass = 1
305           
306        return self.passed
307                         
308    def processStimuli(self, node):
309        """ Process the stimuli list in the TestCase node
310            of an XML test-case file
311            @param node: test-case node
312        """
313        import testcase_generator as generator
314        report = ReportCard()
315        report.trace = "%s\n" % self.model.name
316       
317        self.passed = True
318        if node.hasChildNodes():
319            for n in node.childNodes:
320                if n.nodeName == "Stimulus":
321                   
322                    s_type = None
323                    if n.hasAttributes():
324                        # Get stilumus ID
325                        for i in range(len(n.attributes)):
326                            attr_name = n.attributes.item(i).name
327                            if attr_name == "id":
328                                s_type = n.attributes.item(i).nodeValue
329                    if hasattr(generator,"%sStimulus" % s_type):
330                        stimulus = getattr(generator,"%sStimulus" % s_type)
331                        #print s_type, self.model.name
332                        m, res = stimulus(self.model)
333                        #print "     ", m.name
334                        self.model = m
335                        if not res.isPassed():
336                            self.passed = False
337                       
338                        report += res
339                       
340                    else:
341                        print "Stimulus %s not found" % s_type
342                       
343        self.reportCard += report
344       
345        if not self.passed:
[36948c92]346            print "\nFailure:"
[ae3ce4e]347            print report.trace
348        return self.passed
349       
350   
351def evalStimulus(model):
352    """ Apply the stimulus to the supplied model object
353        @param model: model to apply the stimulus to
354        @return: True if passed, False otherwise
355    """
[36948c92]356    minval = 0.001
357    maxval = 1.0
[ae3ce4e]358    # Call run with random input
359    import random, math
360    input_value = random.random()*(maxval-minval)+minval
361   
362    # Catch division by zero, which can happen
363    # when creating random models
364    try:
365        # Choose whether we want 1D or 2D
366        if random.random()>0.5:
367            output = model.run(input_value)
368        else:
369            output = model.run([input_value, 2*math.pi*random.random()])           
370    except ZeroDivisionError:
[36948c92]371        print "Error evaluating model %s: %g" % (model.name, input_value)
372        output = None
[ae3ce4e]373       
374    #print "Eval: %g" % output
375   
376    # Oracle bit - just check that we have a number...
377    passed = False
[36948c92]378    try:
379        if output != None and math.fabs(output) >= 0: 
380            passed = True
381    except:
382        print "---> Could not compute abs val", output, model.name
383       
[ae3ce4e]384   
385    report = ReportCard()
386    report.n_eval = 1
387    if passed:
388        report.n_eval_pass = 1
389    else:
[36948c92]390        report.log = "Eval: bad output value %s\n" % str(output)
[ae3ce4e]391       
[36948c92]392    report.trace = "Eval(%g) = %s %i\n" % (input_value, str(output), passed)   
[ae3ce4e]393    return model, report
394
395def setStimulus(model):
396    """ Apply the stimulus to the supplied model object
397        @param model: model to apply the stimulus to
398        @return: True if passed, False otherwize
399    """
400    # Call run with random input
401    import random, math
402    minval = 1
[36948c92]403    maxval = 5
[ae3ce4e]404   
405    # Choose a parameter to change
406    keys = model.getParamList()
407    if len(keys)==0:
408        #print model.name+" has no params"
409        return model, ReportCard()
410   
411    p_len = len(keys)
412    i_param = int(math.floor(random.random()*p_len))
413    p_name  = keys[i_param]
414   
[36948c92]415    # Choose a value
416    # Check for min/max
417    if hasattr(model, "details"):
418        if model.details.has_key(p_name):
419            if model.details[p_name][1] != None:
420                minval = model.details[p_name][1]
421            if model.details[p_name][2] != None:
422                maxval = model.details[p_name][2]
423        elif model.other.details.has_key(p_name):
424            if model.other.details[p_name][1] != None:
425                minval = model.other.details[p_name][1]
426            if model.other.details[p_name][2] != None:
427                maxval = model.other.details[p_name][2]
428        elif model.operateOn.details.has_key(p_name):
429            if model.operateOn.details[p_name][1] != None:
430                minval = model.operateOn.details[p_name][1]
431            if model.operateOn.details[p_name][2] != None:
432                maxval = model.operateOn.details[p_name][2]
433           
[ae3ce4e]434    input_val = random.random()*(maxval-minval)+minval
435    model.setParam(p_name, input_val)
436   
437    # Read back
438    check_val = model.getParam(p_name)
439    #print "Set: %s = %g" % (p_name, check_val)
440   
441    # Oracle bit
442    passed = False
443    if check_val == input_val: 
444        passed = True
445   
446    report = ReportCard()
447    report.n_set = 1
448    if passed:
449        report.n_set_pass = 1
450    else:
451        report.log = "Set: parameter %s not set properly %g <> %g\n" % \
452            (p_name, input_val, check_val)
453       
454    report.trace = "Set %s = %g %i\n" % (p_name, input_val, passed)   
455    return model, report
456
457def getStimulus(model):
458    """ Apply the stimulus to the supplied model object
459        @param model: model to apply the stimulus to
460        @return: True if passed, False otherwise
461    """
462    import random, math
463    # Choose a parameter to change
464    keys = model.getParamList()
465    if len(keys)==0:
466        #print model.name+" has no params"
467        return model, ReportCard()   
468   
469    p_len = len(keys)
470    i_param = int(math.floor(random.random()*p_len))
471    p_name  = keys[i_param]
472   
473    # Read back
474    check_val = model.getParam(p_name)
475    #print "Get: %s = %g" % (p_name, check_val)
476   
477    # Oracle bit
478    passed = False
479    if math.fabs(check_val) >= 0: 
480        passed = True
481       
482    report = ReportCard()
483    report.n_get = 1
484    if passed:
485        report.n_get_pass = 1
486    else:
487        report.log = "Get: bad value for parameter %s: %g\n" % \
488            (p_name, check_val)
489       
490    report.trace = "Get %s = %g %i\n" % (p_name, check_val, passed)   
491    return model, report
492
493def loadStimulus(model):
494    """ Apply the stimulus to the supplied model object
495        @param model: model to apply the stimulus to
496        @return: True if passed, False otherwize
497    """
498    from sans.models.ModelIO import ModelIO
499    from sans.models.ModelFactory import ModelFactory
500    factory = ModelFactory()
501    io = ModelIO(factory)
502   
503    # testModel.xml should be in the directory
504    loaded = io.load("load_test_model.xml")
505    value2 = loaded.run(1)
506   
507    # Oracle bit
508    passed = False
509    if not value2 == 0: 
510        passed = True
511       
512    report = ReportCard()
513    report.n_load = 1
514    if passed:
515        report.n_load_pass = 1
516    else:
517        report.log = "Load: bad loaded model\n"
518       
519    report.trace = "Load = %s %i\n" % (loaded.name, passed)   
520    return model, report
521
522def saveStimulus(model):
523    """ Apply the stimulus to the supplied model object
524        @param model: model to apply the stimulus to
525        @return: True if passed, False otherwize
526    """
527    from sans.models.ModelIO import ModelIO
528    from sans.models.ModelFactory import ModelFactory
529    factory = ModelFactory()
530    io = ModelIO(factory)
531    io.save(model,"testModel.xml")
532    #value = model.run(1)
533
534    # Check it
535    loaded = io.load("testModel.xml")
536    try:
537        value2 = loaded.run(1)
538    except ZeroDivisionError:
539        value2 = -1
540   
541    # Oracle bit
542    passed = False
543   
544    # The two outputs do not need to be the same
545    # since we do not save the parameters!
546    # If it's the subtraction of two identical models,
547    # we will also get zero!
548    #if value2 == value:
549    #    passed = True
550   
551    passed = True
552
553       
554    report = ReportCard()
555    report.n_save = 1
556    if passed:
557        report.n_save_pass = 1
558    else:
559        report.log = "Save: bad output from saved model %g\n" % value2
560       
561    report.trace = "Save %s %i\n" % (model.name, passed)   
562    return model, report
563
564def getRandomModelObject():
565    """
566        Return a random model object
567    """
568    from sans.models.ModelFactory import ModelFactory
569    while True:
570        try:
571            modelname = randomModel()
572            add_model = ModelFactory().getModel(modelname)
573            return add_model
574        except:
575            # Don't complain when a model can't be loaded.
576            # A list of tested models will appear in the
577            # file report, which can be compared with the
578            # list of available models, also printed.
579            pass
580            #print "Could not load ", modelname
581   
582
583def addStimulus(model):
584    """ Apply the stimulus to the supplied model object
585        @param model: model to apply the stimulus to
586        @return: True if passed, False otherwize
587    """
588    #from sans.models.ModelFactory import ModelFactory
589    #factory = ModelFactory()
590    #add_model = factory.getModel("SphereModel")
591    add_model = getRandomModelObject()
592           
593    tmp = model + add_model
594    model = tmp
595   
596    # Oracle bit
597    passed = False
598   
599    try:
600        value2 = model.run(1)
601        value2 = float(value2)
602    except:
603        passed = False
604   
605    # If we made it this far, we have a float
606    passed = True 
607       
608    report = ReportCard()
609    report.n_add = 1
610    if passed:
611        report.n_add_pass = 1
612    else:
613        report.log = "Add: bad output from composite model\n"
614       
615    report.trace = "Div %s %i\n" % (model.name, passed)   
616    return model, report
617
618def subStimulus(model):
619    """ Apply the stimulus to the supplied model object
620        @param model: model to apply the stimulus to
621        @return: True if passed, False otherwize
622    """
623    from sans.models.ModelFactory import ModelFactory
624    from random import random
625   
626    #factory = ModelFactory()
627    #sub_model = factory.getModel("CylinderModel")
628    #sub_model = factory.getModel(randomModel())
629    sub_model = getRandomModelObject()
630
631   
632    sub_model.params["scale"] = 1.1*random()
633    tmp = model - sub_model
634   
635    # Oracle bit
636    passed = False
637   
638    try:
[36948c92]639        value2 = tmp.run(1.1 * (1.0 + random()))
[ae3ce4e]640        value2 = float(value2)
641    except:
[36948c92]642        value2 = None
[ae3ce4e]643        passed = False
644   
645    # If we made it this far, we have a float
646    passed = True 
647       
648    report = ReportCard()
649    report.n_sub = 1
650    if passed:
651        report.n_sub_pass = 1
652    else:
653        report.log = "Sub: bad output from composite model\n"
654
[36948c92]655    report.trace = "Sub %s (%g - %g = %s) %i\n" % \
[ae3ce4e]656        (model.name, model.run(1), \
[36948c92]657         sub_model.run(1), str(value2), passed)               
[ae3ce4e]658    return tmp, report
659
660def mulStimulus(model):
661    """ Apply the stimulus to the supplied model object
662        @param model: model to apply the stimulus to
663        @return: True if passed, False otherwize
664    """
665    #from sans.models.ModelFactory import ModelFactory
666    #factory = ModelFactory()
667    #mul_model = factory.getModel("SphereModel")
668    #mul_model = factory.getModel(randomModel())
669    mul_model = getRandomModelObject()
670    tmp = model * mul_model
671   
672    # Oracle bit
673    passed = False
[36948c92]674   
675    from random import random
676    input_val = 1.1 * (1.0 + random())
677    v1 = None
678    v2 = None
[ae3ce4e]679    try:
[36948c92]680        v1 = mul_model.run(input_val)
681        v2 = model.run(input_val)
682        value2 = tmp.run(input_val)
[ae3ce4e]683        value2 = float(value2)
[36948c92]684    except ZeroDivisionError:
685        value2 = None
[ae3ce4e]686   
687    # If we made it this far, we have a float
688    passed = True 
689       
690    report = ReportCard()
691    report.n_mul = 1
692    if passed:
693        report.n_mul_pass = 1
694    else:
695        report.log = "Mul: bad output from composite model\n"
696       
[36948c92]697    report.trace = "Mul %s (%s * %s = %s) %i\n" % \
698      (model.name, str(v1), str(v2), str(value2), passed)               
[ae3ce4e]699    return tmp, report
700
701def divStimulus(model):
702    """ Apply the stimulus to the supplied model object
703        @param model: model to apply the stimulus to
704        @return: True if passed, False otherwise
705    """
706    #from sans.models.ModelFactory import ModelFactory
707    #factory = ModelFactory()
708    #div_model = factory.getModel("SphereModel")
709    #div_model = factory.getModel(randomModel())
[36948c92]710
711    from random import random
712    input_val = 1.5 * random()
[ae3ce4e]713   
[36948c92]714    # Find a model that will not evaluate to zero
715    # at the input value
716    try_again = True
717    while try_again:
718        div_model = getRandomModelObject()
719        try:
720            v2 = div_model.run(input_val)
721            try_again = False
722        except:
723            pass
724       
[ae3ce4e]725    tmp = model / div_model
726   
727    # Oracle bit
728    passed = False
729   
[36948c92]730    v1 = None
731    v2 = None
[ae3ce4e]732    try:
[36948c92]733       
734        # Check individual models against bad combinations,
735        # which happen from time to time given that all
736        # parameters are random
737        try:
738            v2 = div_model.run(input_val)
739            v1 = model.run(input_val)
[ae3ce4e]740            value2 = tmp.run(input_val)
741            value2 = float(value2)
[36948c92]742        except ZeroDivisionError:
743            value2 = None
[ae3ce4e]744    except:
745        passed = False
746   
747    # If we made it this far, we have a float
748    passed = True 
749       
750    report = ReportCard()
751    report.n_div = 1
752    if passed:
753        report.n_div_pass = 1
754    else:
755        report.log = "Div: bad output from composite model\n"
756       
[36948c92]757    report.trace = "Div %s/%s (%g) = %s / %s = %s %i\n" % \
758      (model.name, div_model.name, input_val, str(v1), str(v2), str(value2), passed)               
[ae3ce4e]759    return tmp, report
760
761if __name__ == '__main__':
762
763    #print randomModel()
764    g = TestCaseGenerator()
[36948c92]765    g.generateAndRun(20000)
[ae3ce4e]766   
767    #t = TestCase(filename = "error_1.17721e+009.xml")
768    #print t.run()
769   
Note: See TracBrowser for help on using the repository browser.