source: sasview/pr_inversion/num_term.py @ c09ac449

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since c09ac449 was 7578961, checked in by Mathieu Doucet <doucetm@…>, 16 years ago

Small improvements

  • Property mode set to 100644
File size: 6.0 KB
Line 
1import unittest, math, numpy, pylab, sys, string
2from sans.pr.invertor import Invertor
3
4class Num_terms():   
5   
6     def __init__(self, invertor):
7         self.invertor = invertor
8         self.nterm_min = 10
9         self.nterm_max = len(self.invertor.x)
10         if self.nterm_max>50:
11             self.nterm_max=50
12         self.isquit_func = None
13         
14         self.osc_list = []
15         self.err_list = []
16         self.alpha_list = []
17         self.mess_list = []
18         
19         self.dataset = []
20     
21     def is_odd(self, n):
22         return bool(n%2)
23
24     def sort_osc(self):
25         import copy
26         osc = copy.deepcopy(self.dataset)
27         lis = []
28         for i in range(len(osc)):
29             osc.sort()
30             re = osc.pop(0)
31             lis.append(re)
32         return lis
33           
34     def median_osc(self):
35         osc = self.sort_osc()
36         dv = len(osc)
37         med = float(dv) / 2.0
38         odd = self.is_odd(dv)
39         medi = 0
40         for i in range(dv):
41             if odd == True:
42                 medi = osc[int(med)]
43             else:
44                 medi = osc[int(med) - 1]
45         return medi
46
47     def get0_out(self):
48        inver = self.invertor
49        self.osc_list = []
50        self.err_list = []
51        self.alpha_list = []
52        for k in range(self.nterm_min, self.nterm_max, 1):
53            if self.isquit_func != None:
54                self.isquit_func()
55            best_alpha, message, elapsed = inver.estimate_alpha(k)
56            inver.alpha = best_alpha
57            inver.out, inver.cov = inver.lstsq(k)
58            osc = inver.oscillations(inver.out)
59            err = inver.get_pos_err(inver.out, inver.cov)
60            if osc>10.0:
61                break
62            self.osc_list.append(osc)
63            self.err_list.append(err)
64            self.alpha_list.append(inver.alpha)
65            self.mess_list.append(message)
66         
67        #print "osc", self.osc_list
68        #print "err", self.err_list
69        #print "alpha", self.alpha_list
70        new_ls = []
71        new_osc1 = []
72        new_osc2= []
73        new_osc3 = []
74        flag9=False
75        flag8=False
76        flag7=False
77        for i in range(len(self.err_list)):
78            if self.err_list[i] <= 1.0 and self.err_list[i] >=0.9:
79                new_osc1.append(self.osc_list[i])
80                flag9=True
81            if self.err_list[i] < 0.9 and self.err_list[i] >=0.8:
82                new_osc2.append(self.osc_list[i])
83                flag8=True
84            if self.err_list[i] <0.8 and self.err_list[i] >= 0.7:
85                new_osc3.append(self.osc_list[i])
86                falg7=True
87                 
88        if flag9==True:
89            self.dataset = new_osc1
90        elif flag8==True:
91            self.dataset = new_osc2
92        else:
93            self.dataset = new_osc3
94         
95        #print "dataset", self.dataset
96        return self.dataset
97       
98     def ls_osc(self):
99        # Generate data
100        ls_osc = self.get0_out()
101        med = self.median_osc()
102       
103        #TODO: check 1
104        ls_osc = self.dataset
105        ls = []
106        for i in range(len(ls_osc)):
107            if int(med) == int(ls_osc[i]):
108                ls.append(ls_osc[i])
109        return ls
110
111     def compare_err(self):
112        ls = self.ls_osc()
113        #print "ls", ls
114        nt_ls = []
115        for i in range(len(ls)):
116            r = ls[i]
117            n = self.osc_list.index(r) + 10
118            #er = self.err_list[n]
119            #nt = self.osc_list.index(r) + 10
120            nt_ls.append(n)
121        #print "nt list", nt_ls
122        return nt_ls
123
124     def num_terms(self, isquit_func=None):
125         try:
126             self.isquit_func = isquit_func
127             #self.nterm_max = len(self.invertor.x)
128             #self.nterm_max = 32
129             nts = self.compare_err()
130             #print "nts", nts
131             div = len(nts)
132             tem = float(div)/2.0
133             odd = self.is_odd(div)
134             if odd == True:
135                 nt = nts[int(tem)]
136             else:
137                 nt = nts[int(tem) - 1]
138             return nt, self.alpha_list[nt - 10], self.mess_list[nt-10]
139         except:
140             return self.nterm_min, self.alpha_list[10], self.mess_list[10]
141
142#For testing
143def load(path):
144        import numpy, math, sys
145        # Read the data from the data file
146        data_x   = numpy.zeros(0)
147        data_y   = numpy.zeros(0)
148        data_err = numpy.zeros(0)
149        scale    = None
150        min_err  = 0.0
151        if not path == None:
152            input_f = open(path,'r')
153            buff    = input_f.read()
154            lines   = buff.split('\n')
155            for line in lines:
156                try:
157                    toks = line.split()
158                    x = float(toks[0])
159                    y = float(toks[1])
160                    if len(toks)>2:
161                        err = float(toks[2])
162                    else:
163                        if scale==None:
164                            scale = 0.05*math.sqrt(y)
165                            #scale = 0.05/math.sqrt(y)
166                            min_err = 0.01*y
167                        err = scale*math.sqrt(y)+min_err
168                        #err = 0
169                       
170                    data_x = numpy.append(data_x, x)
171                    data_y = numpy.append(data_y, y)
172                    data_err = numpy.append(data_err, err)
173                except:
174                    pass
175                   
176        return data_x, data_y, data_err
177   
178
179if __name__ == "__main__":
180    i = Invertor()
181    x, y, erro = load("test/Cyl_A_D102.txt")
182    i.d_max = 102.0
183    i.nfunc = 10   
184    #i.q_max = 0.4
185    #i.q_min = 0.07
186    i.x   = x
187    i.y   = y
188    i.err = erro   
189    #i.out, i.cov = i.lstsq(10)
190    # Testing estimator
191    est = Num_terms(i)
192    print est.num_terms()
193   
194       
195       
196       
197       
198       
Note: See TracBrowser for help on using the repository browser.