source: sasview/sansmodels/test/testcase_generator.py @ e03a14e

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since e03a14e was 18e250c, checked in by Gervaise Alina <gervyh@…>, 13 years ago

move test to sansmodels top level

  • Property mode set to 100644
File size: 24.1 KB
Line 
1"""
2    Random test-case generator for BaseComponents
3   
4"""
5import time
6
7def randomModel():
8    """ Return a random model name """
9    from sans.models.ModelFactory import ModelFactory
10    from random import random
11    from math import floor
12   
13    model_list = ModelFactory().getAllModels()
14    ready = False
15    while not ready:
16        rnd_id = int(floor(random()*len(model_list)))
17        if model_list[rnd_id] not in ["TestSphere2", "DataModel"]:
18            ready = True
19    return model_list[rnd_id]
20
21class ReportCard: # pylint: disable-msg=R0902
22    """ Class to hold test-case results """
23   
24    def __init__(self):
25        """ Initialization """
26        ## Number of test cases
27        self.n_cases = 0
28        ## Number of passed test cases
29        self.n_cases_pass = 0
30        ## Number of Evaluation calls
31        self.n_eval = 0
32        ## Number of passed Evaluation calls
33        self.n_eval_pass = 0
34        ## Number of SetParam calls
35        self.n_set = 0
36        ## Number of passed Set calls
37        self.n_set_pass = 0
38        ## Number of GetParam calls
39        self.n_get = 0
40        ## Number of passed Get calls
41        self.n_get_pass = 0
42        ## Number of load calls
43        self.n_load = 0
44        ## Number of passed Load calls
45        self.n_load_pass = 0
46        ## Number of save calls
47        self.n_save = 0
48        ## Number of passed Save calls
49        self.n_save_pass = 0
50        ## Number of Add calls
51        self.n_add = 0
52        ## Number of passed Add calls
53        self.n_add_pass = 0
54        ## Number of Sub calls
55        self.n_sub = 0
56        ## Number of passed Sub calls
57        self.n_sub_pass = 0
58        ## Number of Div calls
59        self.n_div = 0
60        ## Number of passed Div calls
61        self.n_div_pass = 0
62        ## Number of mul calls
63        self.n_mul = 0
64        ## Number of passed Mul calls
65        self.n_mul_pass = 0
66        ## Test log
67        self.log = ""
68        ## Trace
69        self.trace = ""
70        ## Model tested
71        self.model = ""
72        ## List of all models tested
73        self.modelList = []
74       
75       
76    def __add__(self, other):
77        """ Add two report cards
78            @param other: other report to add
79            @return: updated self
80        """
81       
82        self.n_cases += other.n_cases
83        self.n_cases_pass += other.n_cases_pass
84        self.n_eval += other.n_eval
85        self.n_eval_pass += other.n_eval_pass
86        self.n_set += other.n_set
87        self.n_set_pass += other.n_set_pass
88        self.n_get += other.n_get
89        self.n_get_pass += other.n_get_pass
90        self.n_load += other.n_load
91        self.n_load_pass += other.n_load_pass
92        self.n_save += other.n_save
93        self.n_save_pass += other.n_save_pass
94        self.n_add += other.n_add
95        self.n_add_pass += other.n_add_pass
96        self.n_sub += other.n_sub
97        self.n_sub_pass += other.n_sub_pass
98        self.n_mul += other.n_mul
99        self.n_mul_pass += other.n_mul_pass
100        self.n_div += other.n_div
101        self.n_div_pass += other.n_div_pass
102        if len(other.log)>0:
103            self.log += other.log
104        if len(other.trace)>0:
105            self.trace += other.trace
106           
107        if not other.model in self.modelList:
108            self.modelList.append(other.model)
109           
110        return self
111       
112    def isPassed(self):
113        """ Returns true if no error was found """
114        if self.n_cases_pass < self.n_cases \
115            or self.n_save_pass < self.n_save \
116            or self.n_load_pass < self.n_load \
117            or self.n_set_pass < self.n_set \
118            or self.n_get_pass < self.n_get \
119            or self.n_eval_pass < self.n_eval \
120            or self.n_add_pass < self.n_add \
121            or self.n_sub_pass < self.n_sub \
122            or self.n_mul_pass < self.n_mul \
123            or self.n_div_pass < self.n_div:
124            return False
125        return True
126       
127    def __str__(self):
128        """ String representation of the report card """
129        from sans.models.ModelFactory import ModelFactory
130       
131        rep = "Detailed output:\n"
132        rep += self.log
133        rep += "\n"
134        rep += "Total number of cases: %g\n" % self.n_cases
135        rep += "   Passed:             %g\n" % self.n_cases_pass
136        rep += "\n"
137        self.modelList.sort()
138        rep += "Models tested: %s\n" % self.modelList
139        all_models = ModelFactory().getAllModels()
140        all_models.sort()
141        rep += "\n"
142        rep += "Available models: %s\n" % all_models
143        rep += "\n"
144        rep += "Breakdown:\n"
145        rep += "   Eval:          %g / %g\n" % (self.n_eval_pass, self.n_eval)
146        rep += "   Set:           %g / %g\n" % (self.n_set_pass, self.n_set)
147        rep += "   Get:           %g / %g\n" % (self.n_get_pass, self.n_get)
148        rep += "   Load:          %g / %g\n" % (self.n_load_pass, self.n_load)
149        rep += "   Save:          %g / %g\n" % (self.n_save_pass, self.n_save)
150        rep += "   Add:           %g / %g\n" % (self.n_add_pass, self.n_add)
151        rep += "   Sub:           %g / %g\n" % (self.n_sub_pass, self.n_sub)
152        rep += "   Mul:           %g / %g\n" % (self.n_mul_pass, self.n_mul)
153        rep += "   Div:           %g / %g\n" % (self.n_div_pass, self.n_div)
154        return rep
155   
156       
157
158class TestCaseGenerator:
159    """ Generator for suite of test-cases
160    """
161   
162    def __init__(self):
163        """ Initialization
164        """
165       
166        self.n_tests = 0
167        self.n_passed = 0
168        self.time = 0
169        self.reportCard = ReportCard()
170   
171    def generateFiles(self, number, file_prefix):
172        """ Generate test-case files
173            @param number: number of files to generate
174            @param file_prefix: prefix for the file names
175        """
176       
177        for i in range(number):
178            filename = "%s_%d.xml" % (file_prefix, i)
179            self.generateFileTest(filename)
180            self.n_tests += 1
181           
182    def generateAndRun(self, number):
183        """ Generate test-cases and run them
184            @param number: number of test-cases to generate
185        """
186        start_time = time.time()
187        for i in range(number):
188            textcase = self.generateTest()
189            t = TestCase(text = textcase)
190            passed = t.run()
191            self.reportCard += t.reportCard
192            if not passed:
193                t = time.localtime()
194                xmloutput = open("error_%i-%i-%i-%i-%i-%i_%i.xml" % \
195                 (t[0],t[1],t[2],t[3],t[4],t[5],self.reportCard.n_cases),'w')
196                xmloutput.write(textcase)
197                xmloutput.close()
198
199               
200        self.time += time.time()-start_time
201        print self.reportCard       
202       
203    def generateFileTest(self, filename = "tmp.xml"):
204        """
205            Write a random test-case in an XML file
206            @param filename: name of file to write to
207        """
208        text = self.generateTest()
209        # Write test case in file
210        xmlfile = open(filename,'w')
211        xmlfile.write(text)
212        xmlfile.close()
213   
214   
215    def generateTest(self):
216        """
217            Generate an XML representation of a random test-case
218        """
219        import random
220       
221        #t = TestCase()
222        text = "<?xml version=\"1.0\"?>\n"
223        text  += "<TestCase>\n"
224        stimulus = "eval"
225   
226        loop = True
227        while loop:
228            rnd = random.random()
229           
230            # run 50%
231            if rnd <= 0.5:
232                stimulus = "eval"
233            elif rnd>0.5 and rnd<=0.7:
234                stimulus = "set"
235            elif rnd>0.7 and rnd<=0.72:
236                stimulus = "save"
237            elif rnd>0.72 and rnd<=0.74:
238                stimulus = "load"
239            elif rnd>0.74 and rnd<=0.8:
240                stimulus = "get"
241            elif rnd>0.8 and rnd<=0.81:
242                stimulus = "add"
243            elif rnd>0.81 and rnd<=0.82:
244                stimulus = "div"
245            elif rnd>0.82 and rnd<=0.83:
246                stimulus = "mul"
247            elif rnd>0.83 and rnd<=0.84:
248                stimulus = "sub"
249            else:
250                # just run and quit
251                stimulus = "eval"
252                loop = False
253               
254            text += "  <Stimulus id=\"%s\"/>\n" % stimulus
255        text += "</TestCase>"
256       
257        return text
258   
259       
260class TestCase:
261    """ Test-case class """
262   
263    def __init__(self, filename = None, text = None):
264        """ Initialization
265            @param filename: name of file containing the test case
266        """
267        #from sans.models.ModelFactory import ModelFactory
268        self.filename = filename
269        self.text = text
270        #self.model = ModelFactory().getModel(randomModel())
271        self.model = getRandomModelObject()
272        #self.model = ModelFactory().getModel("SphereModel")
273        self.passed = True
274        self.reportCard = ReportCard()
275       
276   
277    def run(self):
278        """ Read the test case and execute it """
279        from xml.dom.minidom import parse
280        from xml.dom.minidom import parseString
281       
282        # Initialize report
283        self.reportCard = ReportCard()
284        self.reportCard.model = self.model.name
285        self.reportCard.n_cases = 1
286       
287        if not self.text == None:
288            dom = parseString(self.text)
289        elif not self.filename == None:
290            dom = parse(self.filename)
291        else:
292            print "No input to parse"
293            return False
294       
295        if dom.hasChildNodes():
296            for n in dom.childNodes:
297                if n.nodeName == "TestCase":
298                    self.processStimuli(n)
299           
300        # Update report card       
301        if self.passed:
302            self.reportCard.n_cases_pass = 1
303           
304        return self.passed
305                         
306    def processStimuli(self, node):
307        """ Process the stimuli list in the TestCase node
308            of an XML test-case file
309            @param node: test-case node
310        """
311        import testcase_generator as generator
312        report = ReportCard()
313        report.trace = "%s\n" % self.model.name
314       
315        self.passed = True
316        if node.hasChildNodes():
317            for n in node.childNodes:
318                if n.nodeName == "Stimulus":
319                   
320                    s_type = None
321                    if n.hasAttributes():
322                        # Get stilumus ID
323                        for i in range(len(n.attributes)):
324                            attr_name = n.attributes.item(i).name
325                            if attr_name == "id":
326                                s_type = n.attributes.item(i).nodeValue
327                    if hasattr(generator,"%sStimulus" % s_type):
328                        stimulus = getattr(generator,"%sStimulus" % s_type)
329                        #print s_type, self.model.name
330                        m, res = stimulus(self.model)
331                        #print "     ", m.name
332                        self.model = m
333                        if not res.isPassed():
334                            self.passed = False
335                       
336                        report += res
337                       
338                    else:
339                        print "Stimulus %s not found" % s_type
340                       
341        self.reportCard += report
342       
343        if not self.passed:
344            print "\nFailure:"
345            print report.trace
346        return self.passed
347       
348   
349def evalStimulus(model):
350    """ Apply the stimulus to the supplied model object
351        @param model: model to apply the stimulus to
352        @return: True if passed, False otherwise
353    """
354    minval = 0.001
355    maxval = 1.0
356    # Call run with random input
357    import random, math
358    input_value = random.random()*(maxval-minval)+minval
359   
360    # Catch division by zero, which can happen
361    # when creating random models
362    try:
363        # Choose whether we want 1D or 2D
364        if random.random()>0.5:
365            output = model.run(input_value)
366        else:
367            output = model.run([input_value, 2*math.pi*random.random()])           
368    except ZeroDivisionError:
369        print "Error evaluating model %s: %g" % (model.name, input_value)
370        output = None
371       
372    #print "Eval: %g" % output
373   
374    # Oracle bit - just check that we have a number...
375    passed = False
376    try:
377        if output != None and math.fabs(output) >= 0: 
378            passed = True
379    except:
380        print "---> Could not compute abs val", output, model.name
381       
382   
383    report = ReportCard()
384    report.n_eval = 1
385    if passed:
386        report.n_eval_pass = 1
387    else:
388        report.log = "Eval: bad output value %s\n" % str(output)
389       
390    report.trace = "Eval(%g) = %s %i\n" % (input_value, str(output), passed)   
391    return model, report
392
393def setStimulus(model):
394    """ Apply the stimulus to the supplied model object
395        @param model: model to apply the stimulus to
396        @return: True if passed, False otherwize
397    """
398    # Call run with random input
399    import random, math
400    minval = 1
401    maxval = 5
402   
403    # Choose a parameter to change
404    keys = model.getParamList()
405    if len(keys)==0:
406        #print model.name+" has no params"
407        return model, ReportCard()
408   
409    p_len = len(keys)
410    i_param = int(math.floor(random.random()*p_len))
411    p_name  = keys[i_param]
412   
413    # Choose a value
414    # Check for min/max
415    if hasattr(model, "details"):
416        if model.details.has_key(p_name):
417            if model.details[p_name][1] != None:
418                minval = model.details[p_name][1]
419            if model.details[p_name][2] != None:
420                maxval = model.details[p_name][2]
421        elif model.other.details.has_key(p_name):
422            if model.other.details[p_name][1] != None:
423                minval = model.other.details[p_name][1]
424            if model.other.details[p_name][2] != None:
425                maxval = model.other.details[p_name][2]
426        elif model.operateOn.details.has_key(p_name):
427            if model.operateOn.details[p_name][1] != None:
428                minval = model.operateOn.details[p_name][1]
429            if model.operateOn.details[p_name][2] != None:
430                maxval = model.operateOn.details[p_name][2]
431           
432    input_val = random.random()*(maxval-minval)+minval
433    model.setParam(p_name, input_val)
434   
435    # Read back
436    check_val = model.getParam(p_name)
437    #print "Set: %s = %g" % (p_name, check_val)
438   
439    # Oracle bit
440    passed = False
441    if check_val == input_val: 
442        passed = True
443   
444    report = ReportCard()
445    report.n_set = 1
446    if passed:
447        report.n_set_pass = 1
448    else:
449        report.log = "Set: parameter %s not set properly %g <> %g\n" % \
450            (p_name, input_val, check_val)
451       
452    report.trace = "Set %s = %g %i\n" % (p_name, input_val, passed)   
453    return model, report
454
455def getStimulus(model):
456    """ Apply the stimulus to the supplied model object
457        @param model: model to apply the stimulus to
458        @return: True if passed, False otherwise
459    """
460    import random, math
461    # Choose a parameter to change
462    keys = model.getParamList()
463    if len(keys)==0:
464        #print model.name+" has no params"
465        return model, ReportCard()   
466   
467    p_len = len(keys)
468    i_param = int(math.floor(random.random()*p_len))
469    p_name  = keys[i_param]
470   
471    # Read back
472    check_val = model.getParam(p_name)
473    #print "Get: %s = %g" % (p_name, check_val)
474   
475    # Oracle bit
476    passed = False
477    if math.fabs(check_val) >= 0: 
478        passed = True
479       
480    report = ReportCard()
481    report.n_get = 1
482    if passed:
483        report.n_get_pass = 1
484    else:
485        report.log = "Get: bad value for parameter %s: %g\n" % \
486            (p_name, check_val)
487       
488    report.trace = "Get %s = %g %i\n" % (p_name, check_val, passed)   
489    return model, report
490
491def loadStimulus(model):
492    """ Apply the stimulus to the supplied model object
493        @param model: model to apply the stimulus to
494        @return: True if passed, False otherwize
495    """
496    from sans.models.ModelIO import ModelIO
497    from sans.models.ModelFactory import ModelFactory
498    factory = ModelFactory()
499    io = ModelIO(factory)
500   
501    # testModel.xml should be in the directory
502    loaded = io.load("load_test_model.xml")
503    value2 = loaded.run(1)
504   
505    # Oracle bit
506    passed = False
507    if not value2 == 0: 
508        passed = True
509       
510    report = ReportCard()
511    report.n_load = 1
512    if passed:
513        report.n_load_pass = 1
514    else:
515        report.log = "Load: bad loaded model\n"
516       
517    report.trace = "Load = %s %i\n" % (loaded.name, passed)   
518    return model, report
519
520def saveStimulus(model):
521    """ Apply the stimulus to the supplied model object
522        @param model: model to apply the stimulus to
523        @return: True if passed, False otherwize
524    """
525    from sans.models.ModelIO import ModelIO
526    from sans.models.ModelFactory import ModelFactory
527    factory = ModelFactory()
528    io = ModelIO(factory)
529    io.save(model,"testModel.xml")
530    #value = model.run(1)
531
532    # Check it
533    loaded = io.load("testModel.xml")
534    try:
535        value2 = loaded.run(1)
536    except ZeroDivisionError:
537        value2 = -1
538   
539    # Oracle bit
540    passed = False
541   
542    # The two outputs do not need to be the same
543    # since we do not save the parameters!
544    # If it's the subtraction of two identical models,
545    # we will also get zero!
546    #if value2 == value:
547    #    passed = True
548   
549    passed = True
550
551       
552    report = ReportCard()
553    report.n_save = 1
554    if passed:
555        report.n_save_pass = 1
556    else:
557        report.log = "Save: bad output from saved model %g\n" % value2
558       
559    report.trace = "Save %s %i\n" % (model.name, passed)   
560    return model, report
561
562def getRandomModelObject():
563    """
564        Return a random model object
565    """
566    from sans.models.ModelFactory import ModelFactory
567    while True:
568        try:
569            modelname = randomModel()
570            add_model = ModelFactory().getModel(modelname)
571            return add_model
572        except:
573            # Don't complain when a model can't be loaded.
574            # A list of tested models will appear in the
575            # file report, which can be compared with the
576            # list of available models, also printed.
577            pass
578            #print "Could not load ", modelname
579   
580
581def addStimulus(model):
582    """ Apply the stimulus to the supplied model object
583        @param model: model to apply the stimulus to
584        @return: True if passed, False otherwize
585    """
586    #from sans.models.ModelFactory import ModelFactory
587    #factory = ModelFactory()
588    #add_model = factory.getModel("SphereModel")
589    add_model = getRandomModelObject()
590           
591    tmp = model + add_model
592    model = tmp
593   
594    # Oracle bit
595    passed = False
596   
597    try:
598        value2 = model.run(1)
599        value2 = float(value2)
600    except:
601        passed = False
602   
603    # If we made it this far, we have a float
604    passed = True 
605       
606    report = ReportCard()
607    report.n_add = 1
608    if passed:
609        report.n_add_pass = 1
610    else:
611        report.log = "Add: bad output from composite model\n"
612       
613    report.trace = "Div %s %i\n" % (model.name, passed)   
614    return model, report
615
616def subStimulus(model):
617    """ Apply the stimulus to the supplied model object
618        @param model: model to apply the stimulus to
619        @return: True if passed, False otherwize
620    """
621    from sans.models.ModelFactory import ModelFactory
622    from random import random
623   
624    #factory = ModelFactory()
625    #sub_model = factory.getModel("CylinderModel")
626    #sub_model = factory.getModel(randomModel())
627    sub_model = getRandomModelObject()
628
629   
630    sub_model.params["scale"] = 1.1*random()
631    tmp = model - sub_model
632   
633    # Oracle bit
634    passed = False
635   
636    try:
637        value2 = tmp.run(1.1 * (1.0 + random()))
638        value2 = float(value2)
639    except:
640        value2 = None
641        passed = False
642   
643    # If we made it this far, we have a float
644    passed = True 
645       
646    report = ReportCard()
647    report.n_sub = 1
648    if passed:
649        report.n_sub_pass = 1
650    else:
651        report.log = "Sub: bad output from composite model\n"
652
653    report.trace = "Sub %s (%g - %g = %s) %i\n" % \
654        (model.name, model.run(1), \
655         sub_model.run(1), str(value2), passed)               
656    return tmp, report
657
658def mulStimulus(model):
659    """ Apply the stimulus to the supplied model object
660        @param model: model to apply the stimulus to
661        @return: True if passed, False otherwize
662    """
663    #from sans.models.ModelFactory import ModelFactory
664    #factory = ModelFactory()
665    #mul_model = factory.getModel("SphereModel")
666    #mul_model = factory.getModel(randomModel())
667    mul_model = getRandomModelObject()
668    tmp = model * mul_model
669   
670    # Oracle bit
671    passed = False
672   
673    from random import random
674    input_val = 1.1 * (1.0 + random())
675    v1 = None
676    v2 = None
677    try:
678        v1 = mul_model.run(input_val)
679        v2 = model.run(input_val)
680        value2 = tmp.run(input_val)
681        value2 = float(value2)
682    except ZeroDivisionError:
683        value2 = None
684   
685    # If we made it this far, we have a float
686    passed = True 
687       
688    report = ReportCard()
689    report.n_mul = 1
690    if passed:
691        report.n_mul_pass = 1
692    else:
693        report.log = "Mul: bad output from composite model\n"
694       
695    report.trace = "Mul %s (%s * %s = %s) %i\n" % \
696      (model.name, str(v1), str(v2), str(value2), passed)               
697    return tmp, report
698
699def divStimulus(model):
700    """ Apply the stimulus to the supplied model object
701        @param model: model to apply the stimulus to
702        @return: True if passed, False otherwise
703    """
704    #from sans.models.ModelFactory import ModelFactory
705    #factory = ModelFactory()
706    #div_model = factory.getModel("SphereModel")
707    #div_model = factory.getModel(randomModel())
708
709    from random import random
710    input_val = 1.5 * random()
711   
712    # Find a model that will not evaluate to zero
713    # at the input value
714    try_again = True
715    while try_again:
716        div_model = getRandomModelObject()
717        try:
718            v2 = div_model.run(input_val)
719            try_again = False
720        except:
721            pass
722       
723    tmp = model / div_model
724   
725    # Oracle bit
726    passed = False
727   
728    v1 = None
729    v2 = None
730    try:
731       
732        # Check individual models against bad combinations,
733        # which happen from time to time given that all
734        # parameters are random
735        try:
736            v2 = div_model.run(input_val)
737            v1 = model.run(input_val)
738            value2 = tmp.run(input_val)
739            value2 = float(value2)
740        except ZeroDivisionError:
741            value2 = None
742    except:
743        passed = False
744   
745    # If we made it this far, we have a float
746    passed = True 
747       
748    report = ReportCard()
749    report.n_div = 1
750    if passed:
751        report.n_div_pass = 1
752    else:
753        report.log = "Div: bad output from composite model\n"
754       
755    report.trace = "Div %s/%s (%g) = %s / %s = %s %i\n" % \
756      (model.name, div_model.name, input_val, str(v1), str(v2), str(value2), passed)               
757    return tmp, report
758
759if __name__ == '__main__':
760    print "Model operations are no longer supported."
761    print "The TestCase generator for model operation can't be used with this version of sansmodels." 
762   
763   
764    #print randomModel()
765    #g = TestCaseGenerator()
766    #g.generateAndRun(20000)
767   
768    #t = TestCase(filename = "error_1.17721e+009.xml")
769    #print t.run()
770   
Note: See TracBrowser for help on using the repository browser.