source: sasmodels/sasmodels/model_test.py @ 5ca9762

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 5ca9762 was 5ca9762, checked in by pkienzle, 9 years ago

make sure models run even if no verified values are provided; fix error reporting on tests

  • Property mode set to 100644
File size: 4.6 KB
Line 
1# -*- coding: utf-8 -*-
2"""
3Created on Tue Feb 17 11:43:56 2015
4
5@author: David
6"""
7
8import unittest
9import warnings
10import numpy as np
11
12from os.path import basename, dirname, join as joinpath
13from glob import glob
14
15try:
16    from .kernelcl import load_model
17except ImportError,exc:
18    warnings.warn(str(exc))
19    warnings.warn("using ctypes instead")
20    from .kerneldll import load_model
21
22def load_kernel(model, dtype='single'):   
23    kernel = load_model(model, dtype=dtype)
24    kernel.info['defaults'] = dict((p[0],p[2]) for p in kernel.info['parameters'])
25    return kernel
26
27def get_weights(model, pars, name):
28    from . import weights
29   
30    relative = name in model.info['partype']['pd-rel']
31    disperser = pars.get(name+"_pd_type", "gaussian")
32    value = pars.get(name, model.info['defaults'][name])
33    width = pars.get(name+"_pd", 0.0)
34    npts = pars.get(name+"_pd_n", 30)
35    nsigma = pars.get(name+"_pd_nsigma", 3.0)
36    v,w = weights.get_weights(
37            disperser, npts, width, nsigma,
38            value, model.info['limits'][name], relative)
39    return v,w/np.sum(w)
40
41def eval_kernel(kernel, q, pars, cutoff=1e-5):
42    input = kernel.make_input(q)
43    finput = kernel(input)
44
45    fixed_pars = [pars.get(name, finput.info['defaults'][name])
46                  for name in finput.fixed_pars]
47    pd_pars = [get_weights(finput, pars, p) for p in finput.pd_pars]
48    return finput(fixed_pars, pd_pars, cutoff)
49
50def annotate_exception(exc, msg):
51    """
52    Add an annotation to the current exception, which can then be forwarded
53    to the caller using a bare "raise" statement to reraise the annotated
54    exception.
55    Example::
56        >>> D = {}
57        >>> try:
58        ...    print D['hello']
59        ... except Exception,exc:
60        ...    annotate_exception(exc, "while accessing 'D'")
61        ...    raise
62        Traceback (most recent call last):
63            ...
64        KeyError: "hello while accessing 'D'"
65    """
66    args = exc.args
67    if not args:
68        exc.args = (msg,)
69    else:
70        try:
71            arg0 = " ".join((args[0],msg))
72            exc.args = tuple([arg0] + list(args[1:]))
73        except:
74            exc.args = (" ".join((str(exc),msg)),)
75   
76def suite():
77    root = dirname(__file__)
78    files = sorted(glob(joinpath(root, 'models', "[a-zA-Z]*.py")))
79    models_names = [basename(f)[:-3] for f in files]
80   
81    suite = unittest.TestSuite()
82   
83    for model_name in models_names:
84        module = __import__('sasmodels.models.' + model_name)
85        module = getattr(module, 'models', None)
86
87        model = getattr(module, model_name, None)
88        smoke_tests = [[{},0.1,None],[{},(0.1,0.1),None]]
89        tests = smoke_tests + getattr(model, 'tests', [])
90       
91        if tests:
92            #print '------'
93            #print 'found tests in', model_name
94            #print '------'
95   
96            kernel = load_kernel(model)
97            suite.addTest(ModelTestCase(model_name, kernel, tests))
98
99    return suite
100
101class ModelTestCase(unittest.TestCase):
102   
103    def __init__(self, model_name, kernel, tests):
104        unittest.TestCase.__init__(self)
105       
106        self.model_name = model_name
107        self.kernel = kernel
108        self.tests = tests
109
110    def runTest(self):
111        #print '------'
112        #print self.model_name
113        #print '------'
114        try:
115            for test in self.tests:
116                params = test[0]
117                Q = test[1]
118                I = test[2]
119                     
120                if not isinstance(Q, list):
121                    Q = [Q]
122                if not isinstance(I, list):
123                    I = [I]
124                   
125                if isinstance(Q[0], tuple):
126                    Qx,Qy = zip(*Q)
127                    Q_vectors = [np.array(Qx), np.array(Qy)]
128                else:
129                    Q_vectors = [np.array(Q)]
130
131                self.assertEqual(len(I), len(Q))
132           
133                Iq = eval_kernel(self.kernel, Q_vectors, params)
134           
135                self.assertGreater(len(Iq), 0)   
136                self.assertEqual(len(I), len(Iq))             
137               
138                for q, i, iq in zip(Q, I, Iq):
139                    if i is None: continue # smoke test --- make sure it runs
140                    err = np.abs(i - iq)
141                    nrm = np.abs(i)
142           
143                    self.assertLess(err * 10**5, nrm, 'q:%s; expected:%s; actual:%s' % (q, i, iq))
144                   
145        except Exception,exc: 
146            annotate_exception(exc, '\r\nModel: %s' % self.model_name)
147            raise
148
149def main():
150    #unittest.main()
151    runner = unittest.TextTestRunner()
152    runner.run(suite())
153
154if __name__ == "__main__":
155    main()
Note: See TracBrowser for help on using the repository browser.