source: sasmodels/sasmodels/model_test.py @ 9890053

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 9890053 was 9890053, checked in by Paul Kienzle <pkienzle@…>, 9 years ago

add smoke tests for ER/VR; check that smoke test results are valid floats

  • Property mode set to 100644
File size: 7.2 KB
Line 
1# -*- coding: utf-8 -*-
2"""
3Run model unit tests.
4
5Usage::
6
7     python -m sasmodels.model_test [opencl|dll|opencl_and_dll] model1 model2 ...
8
9     if model1 is 'all', then all except the remaining models will be tested
10
11Each model is tested using the default parameters at q=0.1, (qx,qy)=(0.1,0.1),
12and the ER and VR are computed.  The return values at these points are not
13considered.  The test is only to verify that the models run to completion,
14and do not produce inf or NaN.
15
16Tests are defined with the *tests* attribute in the model.py file.  *tests*
17is a list of individual tests to run, where each test consists of the
18parameter values for the test, the q-values and the expected results.  For
19the effective radius test, the q-value should be 'ER'.  For the VR test,
20the q-value should be 'VR'.  For 1-D tests, either specify the q value or
21a list of q-values, and the corresponding I(q) value, or list of I(q) values.
22
23That is::
24
25    tests = [
26        [ {parameters}, q, I(q)],
27        [ {parameters}, [q], [I(q)] ],
28        [ {parameters}, [q1, q2, ...], [I(q1), I(q2), ...]],
29
30        [ {parameters}, (qx, qy), I(qx, qy)],
31        [ {parameters}, [(qx1, qy1), (qx2, qy2), ...], [I(qx1,qy1), I(qx2,qy2), ...]],
32
33        [ {parameters}, 'ER', ER(pars) ],
34        [ {parameters}, 'VR', VR(pars) ],
35        ...
36    ]
37
38Parameters are *key:value* pairs, where key is one of the parameters of the
39model and value is the value to use for the test.  Any parameters not given
40in the parameter list will take on the default parameter value.
41
42Precision defaults to 5 digits (relative).
43"""
44
45import sys
46import unittest
47
48import numpy as np
49
50from .core import list_models, load_model_definition
51from .core import load_model_cl, load_model_dll
52from .core import make_kernel, call_kernel, call_ER, call_VR
53
def annotate_exception(exc, msg):
    """
    Append *msg* to the message carried by the exception *exc*.

    The exception is modified in place (its *args* tuple is replaced), so
    the caller can forward the annotated exception using a bare "raise"
    statement.  Nothing is returned.

    * If *exc* has no arguments, *msg* becomes its only argument.
    * If the first argument is a string, *msg* is appended to it after a
      single space and the remaining arguments are kept unchanged.
    * Otherwise all arguments are replaced by a single string made of
      ``str(exc)`` followed by *msg*.

    Example::

        >>> D = {}
        >>> try:
        ...    print(D['hello'])
        ... except Exception as exc:
        ...    annotate_exception(exc, "while accessing 'D'")
        ...    raise
        Traceback (most recent call last):
            ...
        KeyError: "hello while accessing 'D'"
    """
    args = exc.args
    if not args:
        exc.args = (msg,)
        return

    try:
        annotated = " ".join((args[0], msg))
    except (TypeError, UnicodeError):
        # First argument is not a string that can be joined with msg (e.g.
        # an errno); fall back to the printable form of the exception.
        # Only join failures are caught so that KeyboardInterrupt and
        # SystemExit are never swallowed here.
        exc.args = (" ".join((str(exc), msg)),)
    else:
        exc.args = (annotated,) + tuple(args[1:])
79   
def suite(loaders, models):
    """
    Build the unittest suite for the requested models.

    *loaders* is the collection of kernel types to test; the names checked
    here are 'opencl' and 'dll'.

    *models* is the list of model names to test.  If the first entry is
    'all', every model returned by *list_models()* is tested except those
    named in the remaining entries.  An empty list yields an empty suite.

    Each model gets four default-parameter smoke tests (q=0.1,
    (qx,qy)=(0.1,0.1), ER and VR) followed by the tests listed in the
    *tests* attribute of its model definition, if any.

    Returns a *unittest.TestSuite* holding one *ModelTestCase* for each
    (model, loader) combination that is selected.
    """
    test_suite = unittest.TestSuite()

    # Check for an empty list first: models[0] would raise IndexError.
    if models and models[0] == 'all':
        skip = models[1:]
        models = list_models()
    else:
        skip = []

    for model_name in models:
        if model_name in skip:
            continue
        model_definition = load_model_definition(model_name)

        # Expected value None marks a smoke test: the result only has to
        # be a finite number.
        smoke_tests = [
            [{}, 0.1, None],
            [{}, (0.1, 0.1), None],
            [{}, 'ER', None],
            [{}, 'VR', None],
        ]
        # list() so that a model may define its tests as a tuple.
        tests = smoke_tests + list(getattr(model_definition, 'tests', []))

        # If ispy then use the dll loader to call pykernel; don't try to
        # call the cl kernel since it will not be available in some
        # environments.
        ispy = callable(getattr(model_definition, 'Iq', None))

        # test using opencl if desired
        if not ispy and ('opencl' in loaders and load_model_cl):
            test_name = "Model: %s, Kernel: OpenCL" % model_name
            test_suite.addTest(ModelTestCase(test_name, model_definition,
                                             load_model_cl, tests))

        # test using dll if desired (always, for python models)
        if ispy or ('dll' in loaders and load_model_dll):
            test_name = "Model: %s, Kernel: dll" % model_name
            test_suite.addTest(ModelTestCase(test_name, model_definition,
                                             load_model_dll, tests))

    return test_suite
128
129class ModelTestCase(unittest.TestCase):
130   
131    def __init__(self, test_name, definition, loader, tests):
132        unittest.TestCase.__init__(self)
133       
134        self.test_name = test_name
135        self.definition = definition
136        self.loader = loader
137        self.tests = tests
138
139    def runTest(self):
140        #print "running", self.test_name
141        try:
142            model = self.loader(self.definition)
143            for test in self.tests:
144                pars, Q, I = test
145
146                if not isinstance(I, list):
147                    I = [I]
148                if not isinstance(Q, list):
149                    Q = [Q]
150
151                self.assertEqual(len(I), len(Q))
152
153                if Q[0] == 'ER':
154                    Iq = [call_ER(kernel, pars)]
155                elif Q[0] == 'VR':
156                    Iq = [call_VR(kernel, pars)]
157                elif isinstance(Q[0], tuple):
158                    Qx,Qy = zip(*Q)
159                    Q_vectors = [np.array(Qx), np.array(Qy)]
160                    kernel = make_kernel(model, Q_vectors)
161                    Iq = call_kernel(kernel, pars)
162                else:
163                    Q_vectors = [np.array(Q)]
164                    kernel = make_kernel(model, Q_vectors)
165                    Iq = call_kernel(kernel, pars)
166           
167                self.assertGreater(len(Iq), 0)   
168                self.assertEqual(len(I), len(Iq))             
169               
170                for q, i, iq in zip(Q, I, Iq):
171                    if i is None:
172                        # smoke test --- make sure it runs and produces a value
173                        self.assertTrue(np.isfinite(iq), 'q:%s; not finite; actual:%s' % (q, iq))
174                    else:
175                        err = abs(i - iq)
176                        nrm = abs(i)
177                        self.assertLess(err * 10**5, nrm, 'q:%s; expected:%s; actual:%s' % (q, i, iq))
178                   
179        except Exception,exc: 
180            annotate_exception(exc, self.test_name)
181            raise
182
def main():
    """
    Run the model tests named on the command line.

    The first argument may select the kernel types to test: 'opencl',
    'dll' or 'opencl_and_dll'.  Without it both are requested.  The
    remaining arguments are the model names handed to *suite*.

    Exits with status 1 if opencl is requested explicitly but
    *load_model_cl* is None.

    Returns 0 if every test passed, or 1 if any test failed or no model
    was named (in which case the usage message is printed to stderr).
    """
    loader_options = {
        'opencl': ['opencl'],
        # TODO: test if compiler is available?
        'dll': ['dll'],
        'opencl_and_dll': ['opencl', 'dll'],
    }

    models = sys.argv[1:]
    if models and models[0] in loader_options:
        loaders = loader_options[models[0]]
        models = models[1:]
        # Only an explicit request for opencl is fatal when it is missing.
        if 'opencl' in loaders and load_model_cl is None:
            sys.stderr.write("opencl is not available\n")
            sys.exit(1)
    else:
        loaders = ['opencl', 'dll']

    if not models:
        sys.stderr.write("usage: python -m sasmodels.model_test [opencl|dll|opencl_and_dll] model1 model2 ...\n")
        sys.stderr.write("if model1 is 'all', then all except the remaining models will be tested\n")
        return 1

    runner = unittest.TextTestRunner()
    result = runner.run(suite(loaders, models))
    return 0 if result.wasSuccessful() else 1


if __name__ == "__main__":
    # Propagate the test outcome to the shell as the process exit status.
    sys.exit(main())
Note: See TracBrowser for help on using the repository browser.