source: sasmodels/sasmodels/model_test.py @ f734e7d

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since f734e7d was f734e7d, checked in by pkienzle, 9 years ago

restructure c code generation for maintainability; extend test harness to allow opencl and ctypes tests

  • Property mode set to 100644
File size: 4.9 KB
Line 
1# -*- coding: utf-8 -*-
2"""
3Created on Tue Feb 17 11:43:56 2015
4
5@author: David
6"""
7
8import sys
9import unittest
10
11import numpy as np
12
13from .core import list_models, load_model_definition
14from .core import load_model_cl, load_model_dll
15from .core import make_kernel, call_kernel
16
17def annotate_exception(exc, msg):
18    """
19    Add an annotation to the current exception, which can then be forwarded
20    to the caller using a bare "raise" statement to reraise the annotated
21    exception.
22    Example::
23        >>> D = {}
24        >>> try:
25        ...    print D['hello']
26        ... except Exception,exc:
27        ...    annotate_exception(exc, "while accessing 'D'")
28        ...    raise
29        Traceback (most recent call last):
30            ...
31        KeyError: "hello while accessing 'D'"
32    """
33    args = exc.args
34    if not args:
35        exc.args = (msg,)
36    else:
37        try:
38            arg0 = " ".join((args[0],msg))
39            exc.args = tuple([arg0] + list(args[1:]))
40        except:
41            exc.args = (" ".join((str(exc),msg)),)
42   
43def suite(loaders, models):
44
45    suite = unittest.TestSuite()
46
47    if models[0] == 'all':
48        skip = models[1:]
49        models = list_models()
50    else:
51        skip = []
52    for model_name in models:
53        if model_name in skip: continue
54        model_definition = load_model_definition(model_name)
55
56        smoke_tests = [[{},0.1,None],[{},(0.1,0.1),None]]
57        tests = smoke_tests + getattr(model_definition, 'tests', [])
58       
59        if tests: # in case there are no smoke tests...
60            #print '------'
61            #print 'found tests in', model_name
62            #print '------'
63
64            # if ispy then use the dll loader to call pykernel
65            # don't try to call cl kernel since it will not be
66            # available in some environmentes.
67            ispy = callable(getattr(model_definition,'Iq', None))
68
69            # test using opencl if desired
70            if not ispy and ('opencl' in loaders and load_model_cl):
71                test_name = "Model: %s, Kernel: OpenCL"%model_name
72                test = ModelTestCase(test_name, model_definition,
73                                     load_model_cl, tests)
74                print "defining", test_name
75                suite.addTest(test)
76
77            # test using dll if desired
78            if ispy or ('dll' in loaders and load_model_dll):
79                test_name = "Model: %s, Kernel: dll"%model_name
80                test = ModelTestCase(test_name, model_definition,
81                                     load_model_dll, tests)
82                print "defining", test_name
83                suite.addTest(test)
84
85    return suite
86
87class ModelTestCase(unittest.TestCase):
88   
89    def __init__(self, test_name, definition, loader, tests):
90        unittest.TestCase.__init__(self)
91       
92        self.test_name = test_name
93        self.definition = definition
94        self.loader = loader
95        self.tests = tests
96
97    def runTest(self):
98        print "running", self.test_name
99        try:
100            model = self.loader(self.definition)
101            for test in self.tests:
102                pars, Q, I = test
103
104                if not isinstance(Q, list):
105                    Q = [Q]
106                if not isinstance(I, list):
107                    I = [I]
108                   
109                if isinstance(Q[0], tuple):
110                    Qx,Qy = zip(*Q)
111                    Q_vectors = [np.array(Qx), np.array(Qy)]
112                else:
113                    Q_vectors = [np.array(Q)]
114
115                self.assertEqual(len(I), len(Q))
116
117                kernel = make_kernel(model, Q_vectors)
118                Iq = call_kernel(kernel, pars)
119           
120                self.assertGreater(len(Iq), 0)   
121                self.assertEqual(len(I), len(Iq))             
122               
123                for q, i, iq in zip(Q, I, Iq):
124                    if i is None: continue # smoke test --- make sure it runs
125                    err = abs(i - iq)
126                    nrm = abs(i)
127           
128                    self.assertLess(err * 10**5, nrm, 'q:%s; expected:%s; actual:%s' % (q, i, iq))
129                   
130        except Exception,exc: 
131            annotate_exception(exc, self.test_name)
132            raise
133
134def main():
135
136    models = sys.argv[1:]
137    if models and models[0] == 'opencl':
138        if load_model_cl is None:
139            print >>sys.stderr, "opencl is not available"
140            sys.exit(1)
141        loaders = ['opencl']
142        models = models[1:]
143    elif models and models[0] == 'dll':
144        # TODO: test if compiler is available?
145        loaders = ['dll']
146        models = models[1:]
147    else:
148        loaders = ['opencl', 'dll']
149    if models:
150        runner = unittest.TextTestRunner()
151        runner.run(suite(loaders, models))
152    else:
153        print >>sys.stderr, "usage: python -m sasmodels.model_test [opencl|dll] model1 model2 ..."
154        print >>sys.stderr, "if model1 is 'all', then all except the remaining models will be tested"
155
156if __name__ == "__main__":
157    main()
Note: See TracBrowser for help on using the repository browser.