source: sasview/pr_inversion/num_term.py @ 35adaf6

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 35adaf6 was e96a852, checked in by Mathieu Doucet <doucetm@…>, 16 years ago

From Raiza: n_term estimator

  • Property mode set to 100644
File size: 5.9 KB
Line 
1import unittest, math, numpy, pylab, sys, string
2from sans.pr.invertor import Invertor
3
4class Num_terms():   
5   
6     def __init__(self, invertor):
7         self.invertor = invertor
8         self.nterm_min = 10
9         self.nterm_max = len(self.invertor.x)
10         self.isquit_func = None
11         
12         self.osc_list = []
13         self.err_list = []
14         self.alpha_list = []
15         self.mess_list = []
16         
17         self.dataset = []
18     
19     def is_odd(self, n):
20         return bool(n%2)
21
22     def sort_osc(self):
23         import copy
24         osc = copy.deepcopy(self.dataset)
25         lis = []
26         for i in range(len(osc)):
27             osc.sort()
28             re = osc.pop(0)
29             lis.append(re)
30         return lis
31           
32     def median_osc(self):
33         osc = self.sort_osc()
34         dv = len(osc)
35         med = float(dv) / 2.0
36         odd = self.is_odd(dv)
37         medi = 0
38         for i in range(dv):
39             if odd == True:
40                 medi = osc[int(med)]
41             else:
42                 medi = osc[int(med) - 1]
43         return medi
44
45     def get0_out(self):
46        inver = self.invertor
47        self.osc_list = []
48        self.err_list = []
49        self.alpha_list = []
50        for k in range(self.nterm_min, self.nterm_max):
51            print "get_out ", k
52            if self.isquit_func != None:
53                self.isquit_func()
54            best_alpha, message, elapsed = inver.estimate_alpha(k)
55            inver.alpha = best_alpha
56            inver.out, inver.cov = inver.lstsq(k)
57            osc = inver.oscillations(inver.out)
58            err = inver.get_pos_err(inver.out, inver.cov)
59           
60            self.osc_list.append(osc)
61            self.err_list.append(err)
62            self.alpha_list.append(inver.alpha)
63            self.mess_list.append(message)
64         
65        #print "osc", self.osc_list
66        #print "err", self.err_list
67        #print "alpha", self.alpha_list
68        new_ls = []
69        new_osc1 = []
70        new_osc2= []
71        new_osc3 = []
72        flag9=False
73        flag8=False
74        flag7=False
75        for i in range(len(self.err_list)):
76            if self.err_list[i] <= 1.0 and self.err_list[i] >=0.9:
77                new_osc1.append(self.osc_list[i])
78                flag9=True
79            if self.err_list[i] < 0.9 and self.err_list[i] >=0.8:
80                new_osc2.append(self.osc_list[i])
81                flag8=True
82            if self.err_list[i] <0.8 and self.err_list[i] >= 0.7:
83                new_osc3.append(self.osc_list[i])
84                falg7=True
85                 
86        if flag9==True:
87            self.dataset = new_osc1
88        elif flag8==True:
89            self.dataset = new_osc2
90        else:
91            self.dataset = new_osc3
92         
93        #print "dataset", self.dataset
94        return self.dataset
95       
96     def ls_osc(self):
97        # Generate data
98        ls_osc = self.get0_out()
99        med = self.median_osc()
100       
101        #TODO: check 1
102        ls_osc = self.dataset
103        ls = []
104        for i in range(len(ls_osc)):
105            if int(med) == int(ls_osc[i]):
106                ls.append(ls_osc[i])
107        return ls
108
109     def compare_err(self):
110        ls = self.ls_osc()
111        #print "ls", ls
112        nt_ls = []
113        for i in range(len(ls)):
114            r = ls[i]
115            n = self.osc_list.index(r) + 10
116            #er = self.err_list[n]
117            #nt = self.osc_list.index(r) + 10
118            nt_ls.append(n)
119        print "nt list", nt_ls
120        return nt_ls
121
122     def num_terms(self, isquit_func=None):
123         try:
124             self.isquit_func = isquit_func
125             #self.nterm_max = len(self.invertor.x)
126             #self.nterm_max = 32
127             nts = self.compare_err()
128             #print "nts", nts
129             div = len(nts)
130             tem = float(div)/2.0
131             odd = self.is_odd(div)
132             if odd == True:
133                 nt = nts[int(tem)]
134             else:
135                 nt = nts[int(tem) - 1]
136             return nt, self.alpha_list[nt - 10], self.mess_list[nt-10]
137         except:
138             return self.nterm_min, self.alpha_list[10], self.mess_list[10]
139
140#For testing
141def load(path):
142        import numpy, math, sys
143        # Read the data from the data file
144        data_x   = numpy.zeros(0)
145        data_y   = numpy.zeros(0)
146        data_err = numpy.zeros(0)
147        scale    = None
148        min_err  = 0.0
149        if not path == None:
150            input_f = open(path,'r')
151            buff    = input_f.read()
152            lines   = buff.split('\n')
153            for line in lines:
154                try:
155                    toks = line.split()
156                    x = float(toks[0])
157                    y = float(toks[1])
158                    if len(toks)>2:
159                        err = float(toks[2])
160                    else:
161                        if scale==None:
162                            scale = 0.05*math.sqrt(y)
163                            #scale = 0.05/math.sqrt(y)
164                            min_err = 0.01*y
165                        err = scale*math.sqrt(y)+min_err
166                        #err = 0
167                       
168                    data_x = numpy.append(data_x, x)
169                    data_y = numpy.append(data_y, y)
170                    data_err = numpy.append(data_err, err)
171                except:
172                    pass
173                   
174        return data_x, data_y, data_err
175   
176
177if __name__ == "__main__":
178    i = Invertor()
179    x, y, erro = load("test/Cyl_A_D102.txt")
180    i.d_max = 102.0
181    i.nfunc = 10   
182    #i.q_max = 0.4
183    #i.q_min = 0.07
184    i.x   = x
185    i.y   = y
186    i.err = erro   
187    #i.out, i.cov = i.lstsq(10)
188    # Testing estimator
189    est = Num_terms(i)
190    print est.num_terms()
191   
192       
193       
194       
195       
196       
Note: See TracBrowser for help on using the repository browser.