1 | """ |
---|
2 | Unit tests for DistExplorer class |
---|
3 | """ |
---|
4 | |
---|
5 | import sys |
---|
6 | import os.path |
---|
7 | import unittest, math, numpy |
---|
8 | from sas.sascalc.pr.invertor import Invertor |
---|
9 | from sas.sascalc.pr.distance_explorer import DistExplorer |
---|
10 | |
---|
def find(filename):
    """Return *filename* joined onto the directory that contains this module.

    Used so the test data files are looked up next to this test file.
    """
    module_dir = os.path.dirname(__file__)
    return os.path.join(module_dir, filename)
---|
13 | |
---|
14 | |
---|
class TestExplorer(unittest.TestCase):
    """Run DistExplorer over the sphere_80.txt sample data."""

    def setUp(self):
        """Configure an Invertor with the sample data and wrap it in a DistExplorer."""
        inv = self.invertor = Invertor()
        q_values, intensities, sigmas = load(find('sphere_80.txt'))

        # d_max considered the right choice for this data set
        inv.d_max = 160.0
        # deliberately small alpha
        inv.alpha = .0007

        inv.x = q_values
        inv.y = intensities
        inv.err = sigmas
        inv.nfunc = 15

        self.explo = DistExplorer(self.invertor)

    def test_exploration(self):
        """Explore with arguments (120, 200, 25) and expect 25 chi2 values, no errors.

        NOTE(review): the arguments are presumably (d_max min, d_max max,
        number of points) -- confirm against DistExplorer.__call__.
        """
        n_points = 25
        results = self.explo(120, 200, n_points)
        self.assertEqual(len(results.errors), 0)
        self.assertEqual(len(results.chi2), n_points)
---|
37 | |
---|
38 | |
---|
# NOTE: duplicated from utest_invertor because importing it with
# `from .utest_invertor import load` failed.
def load(path="sphere_60_q0_2.txt"):
    """Read a whitespace-separated ``x y [err]`` data file into numpy arrays.

    Each line is split on whitespace; the first two tokens are x and y and an
    optional third token is the error on y. Lines that cannot be parsed
    (headers, comments, blank or malformed lines) are skipped.

    When a line has no error column, the error is estimated as
    ``scale * sqrt(y)``, where ``scale`` is fixed once, from the first such
    line, as ``0.15 * sqrt(y_first)``. A line whose estimate would need the
    square root of a negative y is skipped.
    NOTE(review): this makes err = 0.15*sqrt(y_first)*sqrt(y) -- confirm that
    this error model is intended.

    :param path: path of the data file; when None, nothing is read.
    :return: tuple ``(x, y, err)`` of 1-D float64 numpy arrays (empty when
        ``path`` is None or no line could be parsed).
    """
    import math
    import numpy as np

    xs, ys, errs = [], [], []
    scale = None
    if path is not None:
        # `with` guarantees the file handle is closed (the original leaked it).
        with open(path, 'r') as input_f:
            for line in input_f:
                toks = line.split()
                try:
                    x = float(toks[0])
                    y = float(toks[1])
                    if len(toks) > 2:
                        err = float(toks[2])
                    else:
                        if scale is None:
                            scale = 0.15 * math.sqrt(y)
                        err = scale * math.sqrt(y)
                except (ValueError, IndexError):
                    # Non-numeric token, too few tokens (e.g. blank line), or
                    # sqrt of a negative y: skip this line only.
                    continue
                xs.append(x)
                ys.append(y)
                errs.append(err)

    # Convert once at the end instead of np.append per line (quadratic).
    return (np.array(xs, dtype=float),
            np.array(ys, dtype=float),
            np.array(errs, dtype=float))
---|
72 | |
---|
if __name__ == "__main__":
    # Executing this file directly runs the test cases defined in it.
    unittest.main()
---|