source: sasview/pr_inversion/num_term.py @ 6d48919

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 6d48919 was d84a90c, checked in by Gervaise Alina <gervyh@…>, 15 years ago

working on documentation

  • Property mode set to 100644
File size: 6.0 KB
Line 
1import unittest, math, numpy, sys, string
2from sans.pr.invertor import Invertor
3
4class Num_terms():   
5    """
6    """
7    def __init__(self, invertor):
8        """
9        """
10        self.invertor = invertor
11        self.nterm_min = 10
12        self.nterm_max = len(self.invertor.x)
13        if self.nterm_max>50:
14            self.nterm_max=50
15        self.isquit_func = None
16         
17        self.osc_list = []
18        self.err_list = []
19        self.alpha_list = []
20        self.mess_list = []
21         
22        self.dataset = []
23     
24    def is_odd(self, n):
25        """
26        """
27        return bool(n%2)
28
29    def sort_osc(self):
30        """
31        """
32        import copy
33        osc = copy.deepcopy(self.dataset)
34        lis = []
35        for i in range(len(osc)):
36            osc.sort()
37            re = osc.pop(0)
38            lis.append(re)
39        return lis
40           
41    def median_osc(self):
42        """
43        """
44        osc = self.sort_osc()
45        dv = len(osc)
46        med = float(dv) / 2.0
47        odd = self.is_odd(dv)
48        medi = 0
49        for i in range(dv):
50            if odd == True:
51                medi = osc[int(med)]
52            else:
53                medi = osc[int(med) - 1]
54        return medi
55
56    def get0_out(self):
57        """
58        """ 
59        inver = self.invertor
60        self.osc_list = []
61        self.err_list = []
62        self.alpha_list = []
63        for k in range(self.nterm_min, self.nterm_max, 1):
64            if self.isquit_func != None:
65                self.isquit_func()
66            best_alpha, message, elapsed = inver.estimate_alpha(k)
67            inver.alpha = best_alpha
68            inver.out, inver.cov = inver.lstsq(k)
69            osc = inver.oscillations(inver.out)
70            err = inver.get_pos_err(inver.out, inver.cov)
71            if osc>10.0:
72                break
73            self.osc_list.append(osc)
74            self.err_list.append(err)
75            self.alpha_list.append(inver.alpha)
76            self.mess_list.append(message)
77         
78        #print "osc", self.osc_list
79        #print "err", self.err_list
80        #print "alpha", self.alpha_list
81        new_ls = []
82        new_osc1 = []
83        new_osc2= []
84        new_osc3 = []
85        flag9=False
86        flag8=False
87        flag7=False
88        for i in range(len(self.err_list)):
89            if self.err_list[i] <= 1.0 and self.err_list[i] >=0.9:
90                new_osc1.append(self.osc_list[i])
91                flag9=True
92            if self.err_list[i] < 0.9 and self.err_list[i] >=0.8:
93                new_osc2.append(self.osc_list[i])
94                flag8=True
95            if self.err_list[i] <0.8 and self.err_list[i] >= 0.7:
96                new_osc3.append(self.osc_list[i])
97                falg7=True
98                 
99        if flag9==True:
100            self.dataset = new_osc1
101        elif flag8==True:
102            self.dataset = new_osc2
103        else:
104            self.dataset = new_osc3
105         
106        #print "dataset", self.dataset
107        return self.dataset
108       
109    def ls_osc(self):
110        """
111        """
112        # Generate data
113        ls_osc = self.get0_out()
114        med = self.median_osc()
115       
116        #TODO: check 1
117        ls_osc = self.dataset
118        ls = []
119        for i in range(len(ls_osc)):
120            if int(med) == int(ls_osc[i]):
121                ls.append(ls_osc[i])
122        return ls
123
124    def compare_err(self):
125        """
126        """
127        ls = self.ls_osc()
128        #print "ls", ls
129        nt_ls = []
130        for i in range(len(ls)):
131            r = ls[i]
132            n = self.osc_list.index(r) + 10
133            #er = self.err_list[n]
134            #nt = self.osc_list.index(r) + 10
135            nt_ls.append(n)
136        #print "nt list", nt_ls
137        return nt_ls
138
139    def num_terms(self, isquit_func=None):
140        """
141        """
142        try:
143            self.isquit_func = isquit_func
144            #self.nterm_max = len(self.invertor.x)
145            #self.nterm_max = 32
146            nts = self.compare_err()
147            #print "nts", nts
148            div = len(nts)
149            tem = float(div)/2.0
150            odd = self.is_odd(div)
151            if odd == True:
152                nt = nts[int(tem)]
153            else:
154                nt = nts[int(tem) - 1]
155            return nt, self.alpha_list[nt - 10], self.mess_list[nt-10]
156        except:
157            return self.nterm_min, self.alpha_list[10], self.mess_list[10]
158
159#For testing
160def load(path):
161    import numpy, math, sys
162    # Read the data from the data file
163    data_x   = numpy.zeros(0)
164    data_y   = numpy.zeros(0)
165    data_err = numpy.zeros(0)
166    scale    = None
167    min_err  = 0.0
168    if not path == None:
169        input_f = open(path,'r')
170        buff    = input_f.read()
171        lines   = buff.split('\n')
172        for line in lines:
173            try:
174                toks = line.split()
175                x = float(toks[0])
176                y = float(toks[1])
177                if len(toks)>2:
178                    err = float(toks[2])
179                else:
180                    if scale==None:
181                        scale = 0.05*math.sqrt(y)
182                        #scale = 0.05/math.sqrt(y)
183                        min_err = 0.01*y
184                    err = scale*math.sqrt(y)+min_err
185                    #err = 0
186                   
187                data_x = numpy.append(data_x, x)
188                data_y = numpy.append(data_y, y)
189                data_err = numpy.append(data_err, err)
190            except:
191                pass
192               
193    return data_x, data_y, data_err
194
195
196if __name__ == "__main__":
197    i = Invertor()
198    x, y, erro = load("test/Cyl_A_D102.txt")
199    i.d_max = 102.0
200    i.nfunc = 10   
201    #i.q_max = 0.4
202    #i.q_min = 0.07
203    i.x   = x
204    i.y   = y
205    i.err = erro   
206    #i.out, i.cov = i.lstsq(10)
207    # Testing estimator
208    est = Num_terms(i)
209    print est.num_terms()
210   
211       
212       
213       
214       
215       
Note: See TracBrowser for help on using the repository browser.