source: sasmodels/sasmodels/model_test.py @ 6d6508e

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 6d6508e was 6d6508e, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

refactor model_info from dictionary to class

  • Property mode set to 100644
File size: 11.1 KB
Line 
1# -*- coding: utf-8 -*-
2"""
3Run model unit tests.
4
5Usage::
6
7    python -m sasmodels.model_test [opencl|dll|opencl_and_dll] model1 model2 ...
8
9    if model1 is 'all', then all except the remaining models will be tested
10
11Each model is tested using the default parameters at q=0.1, (qx, qy)=(0.1, 0.1),
12and the ER and VR are computed.  The return values at these points are not
13considered.  The test is only to verify that the models run to completion,
14and do not produce inf or NaN.
15
16Tests are defined with the *tests* attribute in the model.py file.  *tests*
17is a list of individual tests to run, where each test consists of the
18parameter values for the test, the q-values and the expected results.  For
19the effective radius test, the q-value should be 'ER'.  For the VR test,
20the q-value should be 'VR'.  For 1-D tests, either specify the q value or
21a list of q-values, and the corresponding I(q) value, or list of I(q) values.
22
23That is::
24
25    tests = [
26        [ {parameters}, q, I(q)],
27        [ {parameters}, [q], [I(q)] ],
28        [ {parameters}, [q1, q2, ...], [I(q1), I(q2), ...]],
29
30        [ {parameters}, (qx, qy), I(qx, qy)],
31        [ {parameters}, [(qx1, qy1), (qx2, qy2), ...],
32                        [I(qx1, qy1), I(qx2, qy2), ...]],
33
34        [ {parameters}, 'ER', ER(pars) ],
35        [ {parameters}, 'VR', VR(pars) ],
36        ...
37    ]
38
39Parameters are *key:value* pairs, where key is one of the parameters of the
40model and value is the value to use for the test.  Any parameters not given
41in the parameter list will take on the default parameter value.
42
43Precision defaults to 5 digits (relative).
44"""
45#TODO: rename to tests so that tab completion works better for models directory
46
47from __future__ import print_function
48
49import sys
50import unittest
51
52import numpy as np
53
54from .core import list_models, load_model_info, build_model, HAVE_OPENCL
55from .details import dispersion_mesh
56from .direct_model import call_kernel, get_weights
57from .exception import annotate_exception
58
59
def call_ER(model_info, values):
    """
    Evaluate the model's effective radius (ER) function for *values*.

    *model_info* is either *model.info* if you have a loaded model, or
    *kernel.info* if you have a model kernel prepared for evaluation.
    *values* is handed to *get_weights* together with each volume parameter.

    Returns 1.0 when the model defines no ER function.  Otherwise ER is
    evaluated on the dispersion mesh of the volume parameters and the
    weighted mean of the resulting radii is returned.
    """
    if model_info.ER is None:
        return 1.0

    # Only parameters whose type is 'volume' contribute to the mesh.
    volume_dispersions = [
        get_weights(par, values)
        for par in model_info.parameters.call_parameters
        if par.type == 'volume'
    ]
    mesh_values, mesh_weights = dispersion_mesh(volume_dispersions)
    radii = model_info.ER(*mesh_values)
    return np.sum(mesh_weights*radii) / np.sum(mesh_weights)
75
def call_VR(model_info, values):
    """
    Evaluate the model's volume ratio (VR) function for *values*.

    *model_info* is either *model.info* if you have a loaded model, or
    *kernel.info* if you have a model kernel prepared for evaluation.
    *values* is handed to *get_weights* together with each volume parameter.

    Returns 1.0 when the model defines no VR function.  Otherwise VR is
    evaluated on the dispersion mesh of the volume parameters, yielding a
    *(whole, part)* pair, and the ratio of the weighted sum of *part* to the
    weighted sum of *whole* is returned.
    """
    if model_info.VR is None:
        return 1.0

    # Only parameters whose type is 'volume' contribute to the mesh.
    volume_dispersions = [
        get_weights(par, values)
        for par in model_info.parameters.call_parameters
        if par.type == 'volume'
    ]
    mesh_values, mesh_weights = dispersion_mesh(volume_dispersions)
    whole, part = model_info.VR(*mesh_values)
    return np.sum(mesh_weights*part)/np.sum(mesh_weights*whole)
91
92
def make_suite(loaders, models):
    """
    Construct the pyunit test suite.

    *loaders* is the list of kernel drivers to use, which is one of
    *["dll", "opencl"]*, *["dll"]* or *["opencl"]*.  For python models,
    the python driver is always used.

    *models* is the list of models to test, or *["all"]* to test all models.
    When the first entry is *"all"*, any remaining entries name models to
    leave out.  An empty list produces an empty suite.

    Returns a *unittest.TestSuite* holding one test per model and driver.
    """
    suite = unittest.TestSuite()
    if not models:
        # Nothing was requested; indexing models[0] below would raise.
        return suite

    if models[0] == 'all':
        skip = set(models[1:])
        models = list_models()
    else:
        skip = set()

    ModelTestCase = _hide_model_case_from_nosetests()
    for model_name in models:
        if model_name in skip:
            continue
        model_info = load_model_info(model_name)

        # Each variant is (kernel label, method suffix, platform, dtype).
        if callable(model_info.Iq):
            # Kernel implemented in python: go through the "dll" platform
            # and never try the cl kernel, since it will not be available
            # in some environments.
            variants = [("python", "python", "dll", "double")]
        else:
            # Kernel implemented in C.
            variants = []
            if 'opencl' in loaders and HAVE_OPENCL:
                # Using dtype=None so that the models that are only
                # correct for double precision are not tested using
                # single precision.  The choice is determined by the
                # presence of *single=False* in the model file.
                variants.append(("OpenCL", "opencl", "ocl", None))
            if 'dll' in loaders:
                variants.append(("dll", "dll", "dll", "double"))

        for kernel_label, method_suffix, platform, dtype in variants:
            test_name = "Model: %s, Kernel: %s" % (model_name, kernel_label)
            test_method_name = "test_%s_%s" % (model_name, method_suffix)
            test = ModelTestCase(test_name, model_info,
                                 test_method_name,
                                 platform=platform, dtype=dtype)
            suite.addTest(test)

    return suite
159
160
def _hide_model_case_from_nosetests():
    """
    Define the *ModelTestCase* class and return it.

    The class lives inside this function instead of at module level; going
    by the function name, this is what keeps nosetests from picking the
    class up on its own.
    """
    class ModelTestCase(unittest.TestCase):
        """
        Test case for a particular model with a particular kernel driver.

        The case runs a simple smoke test to make sure the model
        functions, then runs the list of tests at the bottom of the model
        description file.
        """
        def __init__(self, test_name, model_info, test_method_name,
                     platform, dtype):
            self.test_name = test_name
            self.info = model_info
            self.platform = platform
            self.dtype = dtype

            # TestCase.__init__ is told to run *test_method_name*, so bind
            # the real test body to that name on the instance first.
            setattr(self, test_method_name, self._runTest)
            unittest.TestCase.__init__(self, test_method_name)

        def _runTest(self):
            """
            Build the model, then run the smoke tests and the model's tests.

            Any exception is annotated with the test name and re-raised.
            """
            # Default parameters at a 1-D q, a 2-D (qx, qy), ER and VR; an
            # expected value of None only requires a finite result.
            smoke_tests = [[{}, q, None]
                           for q in (0.1, (0.1, 0.1), 'ER', 'VR')]

            declared_tests = self.info.tests
            try:
                model = build_model(self.info, dtype=self.dtype,
                                    platform=self.platform)
                for case in smoke_tests + declared_tests:
                    self._run_one_test(model, case)

                if not declared_tests and self.platform == "dll":
                    ## Uncomment the following to make forgetting the test
                    ## values an error.  Only do so for the "dll" tests
                    ## to reduce noise from both opencl and dll, and because
                    ## python kernels use platform="dll".
                    #raise Exception("No test cases provided")
                    pass

            except:
                # Tag whatever went wrong with the test name, then let it
                # propagate unchanged.
                annotate_exception(self.test_name)
                raise

        def _run_one_test(self, model, test):
            """
            Run a single *[pars, q, expected]* test against *model*.

            *q* is 'ER', 'VR', a 1-D q value, a *(qx, qy)* tuple, or a list
            of q values or tuples.  *expected* is the matching value or list
            of values, with None meaning any finite result is accepted.
            """
            pars, q_points, expected = test

            # Scalars are promoted to one-element lists.
            if not isinstance(expected, list):
                expected = [expected]
            if not isinstance(q_points, list):
                q_points = [q_points]

            self.assertEqual(len(expected), len(q_points))

            first_q = q_points[0]
            if first_q == 'ER':
                actual = [call_ER(model.info, pars)]
            elif first_q == 'VR':
                actual = [call_VR(model.info, pars)]
            else:
                if isinstance(first_q, tuple):
                    # 2-D: split the (qx, qy) pairs into two vectors.
                    qx_values, qy_values = zip(*q_points)
                    q_vectors = [np.array(qx_values), np.array(qy_values)]
                else:
                    q_vectors = [np.array(q_points)]
                kernel = model.make_kernel(q_vectors)
                actual = call_kernel(kernel, pars)

            self.assertGreater(len(actual), 0)
            self.assertEqual(len(expected), len(actual))

            for q_i, expected_i, actual_i in zip(q_points, expected, actual):
                if expected_i is None:
                    # smoke test --- make sure it runs and produces a value
                    self.assertTrue(np.isfinite(actual_i),
                                    'invalid f(%s): %s' % (q_i, actual_i))
                else:
                    self.assertTrue(is_near(expected_i, actual_i, 5),
                                    'f(%s); expected:%s; actual:%s'
                                    % (q_i, expected_i, actual_i))

    return ModelTestCase
245
def is_near(target, actual, digits=5):
    """
    Returns true if *actual* is within *digits* significant digits of *target*.

    The tolerance is relative to the power of ten at or above the magnitude
    of *target*.  Targets that have no such scale are compared directly:

    * a *target* of zero matches only an *actual* of zero;
    * a NaN *target* matches only a NaN *actual*;
    * an infinite *target* matches only the same infinity.
    """
    import math

    if target == 0:
        # log10(0) is undefined, so there is no scale to compare against.
        return actual == 0
    if math.isnan(target):
        return math.isnan(actual)
    if math.isinf(target):
        return target == actual

    shift = 10**math.ceil(math.log10(abs(target)))
    return abs(target-actual)/shift < 1.5*10**-digits
253
def main():
    """
    Run tests given in sys.argv.

    The arguments are an optional *-v*, an optional compute target
    (*opencl*, *dll* or *opencl_and_dll*), and then the model names.

    Returns 0 if success or 1 if any tests fail.  Also returns 1, without
    running anything, when no models are named or when *opencl* is requested
    but not available.
    """
    models = sys.argv[1:]

    verbosity = 1
    if models and models[0] == '-v':
        verbosity = 2
        models = models[1:]

    # Without an explicit target both drivers are requested.
    loaders = ['opencl', 'dll']
    if models and models[0] in ('opencl', 'dll', 'opencl_and_dll'):
        target = models[0]
        models = models[1:]
        if target == 'opencl':
            if not HAVE_OPENCL:
                print("opencl is not available")
                return 1
            loaders = ['opencl']
        elif target == 'dll':
            # TODO: test if compiler is available?
            loaders = ['dll']

    if not models:
        print("""\
usage:
  python -m sasmodels.model_test [-v] [opencl|dll|opencl_and_dll] model1 model2 ...

If -v is included on the command line, then use verbose output.

If neither opencl nor dll is specified, then models will be tested with
both opencl and dll; the compute target is ignored for pure python models.

If model1 is 'all', then all except the remaining models will be tested.

""")

        return 1

    # Imported only once a runner is needed, so that the usage message and
    # the opencl check above work without xmlrunner installed.
    import xmlrunner

    #runner = unittest.TextTestRunner()
    runner = xmlrunner.XMLTestRunner(output='logs', verbosity=verbosity)
    result = runner.run(make_suite(loaders, models))
    return 1 if result.failures or result.errors else 0
303
304
def model_tests():
    """
    Test generator visible to nosetests.

    Builds the suite for every model with both the opencl and dll drivers
    requested, and yields the bound test method of each case in turn.

    Run "nosetests sasmodels" on the command line to invoke it.
    """
    suite = make_suite(['opencl', 'dll'], ['all'])
    for case in suite:
        yield case._runTest
314
315
if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the exit status.
    exit_status = main()
    sys.exit(exit_status)
Note: See TracBrowser for help on using the repository browser.