source: sasview/pr_inversion/num_term.py @ 9e85792

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 9e85792 was bd30f4a5, checked in by Mathieu Doucet <doucetm@…>, 16 years ago

More improvements

  • Property mode set to 100644
File size: 5.9 KB
Line 
1import unittest, math, numpy, pylab, sys, string
2from sans.pr.invertor import Invertor
3
4class Num_terms():   
5   
6     def __init__(self, invertor):
7         self.invertor = invertor
8         self.nterm_min = 10
9         self.nterm_max = len(self.invertor.x)
10         self.isquit_func = None
11         
12         self.osc_list = []
13         self.err_list = []
14         self.alpha_list = []
15         self.mess_list = []
16         
17         self.dataset = []
18     
19     def is_odd(self, n):
20         return bool(n%2)
21
22     def sort_osc(self):
23         import copy
24         osc = copy.deepcopy(self.dataset)
25         lis = []
26         for i in range(len(osc)):
27             osc.sort()
28             re = osc.pop(0)
29             lis.append(re)
30         return lis
31           
32     def median_osc(self):
33         osc = self.sort_osc()
34         dv = len(osc)
35         med = float(dv) / 2.0
36         odd = self.is_odd(dv)
37         medi = 0
38         for i in range(dv):
39             if odd == True:
40                 medi = osc[int(med)]
41             else:
42                 medi = osc[int(med) - 1]
43         return medi
44
45     def get0_out(self):
46        inver = self.invertor
47        self.osc_list = []
48        self.err_list = []
49        self.alpha_list = []
50        for k in range(self.nterm_min, self.nterm_max):
51            if self.isquit_func != None:
52                self.isquit_func()
53            best_alpha, message, elapsed = inver.estimate_alpha(k)
54            inver.alpha = best_alpha
55            inver.out, inver.cov = inver.lstsq(k)
56            osc = inver.oscillations(inver.out)
57            err = inver.get_pos_err(inver.out, inver.cov)
58           
59            self.osc_list.append(osc)
60            self.err_list.append(err)
61            self.alpha_list.append(inver.alpha)
62            self.mess_list.append(message)
63         
64        #print "osc", self.osc_list
65        #print "err", self.err_list
66        #print "alpha", self.alpha_list
67        new_ls = []
68        new_osc1 = []
69        new_osc2= []
70        new_osc3 = []
71        flag9=False
72        flag8=False
73        flag7=False
74        for i in range(len(self.err_list)):
75            if self.err_list[i] <= 1.0 and self.err_list[i] >=0.9:
76                new_osc1.append(self.osc_list[i])
77                flag9=True
78            if self.err_list[i] < 0.9 and self.err_list[i] >=0.8:
79                new_osc2.append(self.osc_list[i])
80                flag8=True
81            if self.err_list[i] <0.8 and self.err_list[i] >= 0.7:
82                new_osc3.append(self.osc_list[i])
83                falg7=True
84                 
85        if flag9==True:
86            self.dataset = new_osc1
87        elif flag8==True:
88            self.dataset = new_osc2
89        else:
90            self.dataset = new_osc3
91         
92        #print "dataset", self.dataset
93        return self.dataset
94       
95     def ls_osc(self):
96        # Generate data
97        ls_osc = self.get0_out()
98        med = self.median_osc()
99       
100        #TODO: check 1
101        ls_osc = self.dataset
102        ls = []
103        for i in range(len(ls_osc)):
104            if int(med) == int(ls_osc[i]):
105                ls.append(ls_osc[i])
106        return ls
107
108     def compare_err(self):
109        ls = self.ls_osc()
110        #print "ls", ls
111        nt_ls = []
112        for i in range(len(ls)):
113            r = ls[i]
114            n = self.osc_list.index(r) + 10
115            #er = self.err_list[n]
116            #nt = self.osc_list.index(r) + 10
117            nt_ls.append(n)
118        #print "nt list", nt_ls
119        return nt_ls
120
121     def num_terms(self, isquit_func=None):
122         try:
123             self.isquit_func = isquit_func
124             #self.nterm_max = len(self.invertor.x)
125             #self.nterm_max = 32
126             nts = self.compare_err()
127             #print "nts", nts
128             div = len(nts)
129             tem = float(div)/2.0
130             odd = self.is_odd(div)
131             if odd == True:
132                 nt = nts[int(tem)]
133             else:
134                 nt = nts[int(tem) - 1]
135             return nt, self.alpha_list[nt - 10], self.mess_list[nt-10]
136         except:
137             return self.nterm_min, self.alpha_list[10], self.mess_list[10]
138
139#For testing
140def load(path):
141        import numpy, math, sys
142        # Read the data from the data file
143        data_x   = numpy.zeros(0)
144        data_y   = numpy.zeros(0)
145        data_err = numpy.zeros(0)
146        scale    = None
147        min_err  = 0.0
148        if not path == None:
149            input_f = open(path,'r')
150            buff    = input_f.read()
151            lines   = buff.split('\n')
152            for line in lines:
153                try:
154                    toks = line.split()
155                    x = float(toks[0])
156                    y = float(toks[1])
157                    if len(toks)>2:
158                        err = float(toks[2])
159                    else:
160                        if scale==None:
161                            scale = 0.05*math.sqrt(y)
162                            #scale = 0.05/math.sqrt(y)
163                            min_err = 0.01*y
164                        err = scale*math.sqrt(y)+min_err
165                        #err = 0
166                       
167                    data_x = numpy.append(data_x, x)
168                    data_y = numpy.append(data_y, y)
169                    data_err = numpy.append(data_err, err)
170                except:
171                    pass
172                   
173        return data_x, data_y, data_err
174   
175
176if __name__ == "__main__":
177    i = Invertor()
178    x, y, erro = load("test/Cyl_A_D102.txt")
179    i.d_max = 102.0
180    i.nfunc = 10   
181    #i.q_max = 0.4
182    #i.q_min = 0.07
183    i.x   = x
184    i.y   = y
185    i.err = erro   
186    #i.out, i.cov = i.lstsq(10)
187    # Testing estimator
188    est = Num_terms(i)
189    print est.num_terms()
190   
191       
192       
193       
194       
195       
Note: See TracBrowser for help on using the repository browser.