source: sasmodels/sasmodels/model_test.py @ 5428233

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 5428233 was 5428233, checked in by David Mannicke <davidm@…>, 9 years ago

model_test.py was added to simply testing of models (see end of cylinder.py for more details)

  • Property mode set to 100644
File size: 4.4 KB
Line 
1# -*- coding: utf-8 -*-
2"""
3Created on Tue Feb 17 11:43:56 2015
4
5@author: David
6"""
7
8import unittest
9import warnings
10import numpy as np
11
12from os.path import basename, dirname, join as joinpath
13from glob import glob
14
15try:
16    from .kernelcl import load_model
17except ImportError,exc:
18    warnings.warn(str(exc))
19    warnings.warn("using ctypes instead")
20    from .kerneldll import load_model
21
22def load_kernel(model, dtype='single'):   
23    kernel = load_model(model, dtype=dtype)
24    kernel.info['defaults'] = dict((p[0],p[2]) for p in kernel.info['parameters'])
25    return kernel
26
27def get_weights(model, pars, name):
28    from . import weights
29   
30    relative = name in model.info['partype']['pd-rel']
31    disperser = pars.get(name+"_pd_type", "gaussian")
32    value = pars.get(name, model.info['defaults'][name])
33    width = pars.get(name+"_pd", 0.0)
34    npts = pars.get(name+"_pd_n", 30)
35    nsigma = pars.get(name+"_pd_nsigma", 3.0)
36    v,w = weights.get_weights(
37            disperser, npts, width, nsigma,
38            value, model.info['limits'][name], relative)
39    return v,w/np.sum(w)
40
41def eval_kernel(kernel, q, pars, cutoff=1e-5):
42    input = kernel.make_input(q)
43    finput = kernel(input)
44
45    fixed_pars = [pars.get(name, finput.info['defaults'][name])
46                  for name in finput.fixed_pars]
47    pd_pars = [get_weights(finput, pars, p) for p in finput.pd_pars]
48    return finput(fixed_pars, pd_pars, cutoff)
49
50def annotate_exception(exc, msg):
51    """
52    Add an annotation to the current exception, which can then be forwarded
53    to the caller using a bare "raise" statement to reraise the annotated
54    exception.
55    Example::
56        >>> D = {}
57        >>> try:
58        ...    print D['hello']
59        ... except Exception,exc:
60        ...    annotate_exception(exc, "while accessing 'D'")
61        ...    raise
62        Traceback (most recent call last):
63            ...
64        KeyError: "hello while accessing 'D'"
65    """
66    args = exc.args
67    if not args:
68        arg0 = msg
69    else:
70        arg0 = " ".join((args[0],msg))
71    exc.args = tuple([arg0] + list(args[1:]))
72   
73def suite():
74    root = dirname(__file__)
75    files = sorted(glob(joinpath(root, 'models', "[a-zA-Z]*.py")))
76    models_names = [basename(f)[:-3] for f in files]
77   
78    suite = unittest.TestSuite()
79   
80    for model_name in models_names:
81        module = __import__('sasmodels.models.' + model_name)
82        module = getattr(module, 'models', None)
83
84        model = getattr(module, model_name, None)
85        tests = getattr(model, 'tests', [])
86       
87        if tests:
88            #print '------'
89            #print 'found tests in', model_name
90            #print '------'
91   
92            kernel = load_kernel(model)
93            suite.addTest(ModelTestCase(model_name, kernel, tests))
94
95    return suite
96
97class ModelTestCase(unittest.TestCase):
98   
99    def __init__(self, model_name, kernel, tests):
100        unittest.TestCase.__init__(self)
101       
102        self.model_name = model_name
103        self.kernel = kernel
104        self.tests = tests
105
106    def runTest(self):
107        #print '------'
108        #print self.model_name
109        #print '------'
110        try:
111            for test in self.tests:
112                params = test[0]
113                Q = test[1]
114                I = test[2]
115                     
116                if not isinstance(Q, list):
117                    Q = [Q]
118                if not isinstance(I, list):
119                    I = [I]
120                   
121                if isinstance(Q[0], tuple):
122                    npQ = [np.array([Qi[d] for Qi in Q]) for d in xrange(len(Q[0]))]
123                else:
124                    npQ = [np.array(Q)]
125
126                self.assertTrue(Q)
127                self.assertEqual(len(I), len(Q))   
128           
129                Iq = eval_kernel(self.kernel, npQ, params)
130           
131                self.assertGreater(len(Iq), 0)   
132                self.assertEqual(len(I), len(Iq))             
133               
134                for q, i, iq in zip(Q, I, Iq):
135                    err = np.abs(i - iq)
136                    nrm = np.abs(i)
137           
138                    self.assertLess(err * 10**5, nrm, 'q:%s; expected:%s; actual:%s' % (q, i, iq))
139                   
140        except Exception,exc: 
141            annotate_exception(exc, '\r\nModel: %s' % self.model_name)
142            raise
143
144def main():
145    #unittest.main()
146    runner = unittest.TextTestRunner()
147    runner.run(suite())
148
149if __name__ == "__main__":
150    main()
Note: See TracBrowser for help on using the repository browser.