source: sasview/pr_inversion/src/sans/pr/num_term.py @ d560a37

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since d560a37 was 34f3ad0, checked in by Mathieu Doucet <doucetm@…>, 13 years ago

make pylint happier

  • Property mode set to 100644
File size: 5.7 KB
Line 
1import math
2import numpy
3import copy
4from sans.pr.invertor import Invertor
5
6
7class Num_terms():
8    """
9    """
10    def __init__(self, invertor):
11        """
12        """
13        self.invertor = invertor
14        self.nterm_min = 10
15        self.nterm_max = len(self.invertor.x)
16        if self.nterm_max > 50:
17            self.nterm_max = 50
18        self.isquit_func = None
19         
20        self.osc_list = []
21        self.err_list = []
22        self.alpha_list = []
23        self.mess_list = []
24         
25        self.dataset = []
26     
27    def is_odd(self, n):
28        """
29        """
30        return bool(n % 2)
31
32    def sort_osc(self):
33        """
34        """
35        #import copy
36        osc = copy.deepcopy(self.dataset)
37        lis = []
38        for i in range(len(osc)):
39            osc.sort()
40            re = osc.pop(0)
41            lis.append(re)
42        return lis
43           
44    def median_osc(self):
45        """
46        """
47        osc = self.sort_osc()
48        dv = len(osc)
49        med = float(dv) / 2.0
50        odd = self.is_odd(dv)
51        medi = 0
52        for i in range(dv):
53            if odd == True:
54                medi = osc[int(med)]
55            else:
56                medi = osc[int(med) - 1]
57        return medi
58
59    def get0_out(self):
60        """
61        """
62        inver = self.invertor
63        self.osc_list = []
64        self.err_list = []
65        self.alpha_list = []
66        for k in range(self.nterm_min, self.nterm_max, 1):
67            if self.isquit_func != None:
68                self.isquit_func()
69            best_alpha, message, _ = inver.estimate_alpha(k)
70            inver.alpha = best_alpha
71            inver.out, inver.cov = inver.lstsq(k)
72            osc = inver.oscillations(inver.out)
73            err = inver.get_pos_err(inver.out, inver.cov)
74            if osc > 10.0:
75                break
76            self.osc_list.append(osc)
77            self.err_list.append(err)
78            self.alpha_list.append(inver.alpha)
79            self.mess_list.append(message)
80         
81        new_osc1 = []
82        new_osc2 = []
83        new_osc3 = []
84        flag9 = False
85        flag8 = False
86        flag7 = False
87        for i in range(len(self.err_list)):
88            if self.err_list[i] <= 1.0 and self.err_list[i] >= 0.9:
89                new_osc1.append(self.osc_list[i])
90                flag9 = True
91            if self.err_list[i] < 0.9 and self.err_list[i] >= 0.8:
92                new_osc2.append(self.osc_list[i])
93                flag8 = True
94            if self.err_list[i] < 0.8 and self.err_list[i] >= 0.7:
95                new_osc3.append(self.osc_list[i])
96                flag7 = True
97                 
98        if flag9 == True:
99            self.dataset = new_osc1
100        elif flag8 == True:
101            self.dataset = new_osc2
102        else:
103            self.dataset = new_osc3
104         
105        return self.dataset
106       
107    def ls_osc(self):
108        """
109        """
110        # Generate data
111        ls_osc = self.get0_out()
112        med = self.median_osc()
113       
114        #TODO: check 1
115        ls_osc = self.dataset
116        ls = []
117        for i in range(len(ls_osc)):
118            if int(med) == int(ls_osc[i]):
119                ls.append(ls_osc[i])
120        return ls
121
122    def compare_err(self):
123        """
124        """
125        ls = self.ls_osc()
126        #print "ls", ls
127        nt_ls = []
128        for i in range(len(ls)):
129            r = ls[i]
130            n = self.osc_list.index(r) + 10
131            #er = self.err_list[n]
132            #nt = self.osc_list.index(r) + 10
133            nt_ls.append(n)
134        #print "nt list", nt_ls
135        return nt_ls
136
137    def num_terms(self, isquit_func=None):
138        """
139        """
140        try:
141            self.isquit_func = isquit_func
142            nts = self.compare_err()
143            div = len(nts)
144            tem = float(div)/2.0
145            odd = self.is_odd(div)
146            if odd == True:
147                nt = nts[int(tem)]
148            else:
149                nt = nts[int(tem) - 1]
150            return nt, self.alpha_list[nt - 10], self.mess_list[nt-10]
151        except:
152            return self.nterm_min, self.alpha_list[10], self.mess_list[10]
153
154
155#For testing
156def load(path):
157    # Read the data from the data file
158    data_x   = numpy.zeros(0)
159    data_y   = numpy.zeros(0)
160    data_err = numpy.zeros(0)
161    scale    = None
162    min_err  = 0.0
163    if not path == None:
164        input_f = open(path,'r')
165        buff    = input_f.read()
166        lines   = buff.split('\n')
167        for line in lines:
168            try:
169                toks = line.split()
170                test_x = float(toks[0])
171                test_y = float(toks[1])
172                if len(toks) > 2:
173                    err = float(toks[2])
174                else:
175                    if scale == None:
176                        scale = 0.05 * math.sqrt(test_y)
177                        #scale = 0.05/math.sqrt(y)
178                        min_err = 0.01 * y
179                    err = scale * math.sqrt(test_y) + min_err
180                    #err = 0
181                   
182                data_x = numpy.append(data_x, test_x)
183                data_y = numpy.append(data_y, test_y)
184                data_err = numpy.append(data_err, err)
185            except:
186                pass
187               
188    return data_x, data_y, data_err
189
190
191if __name__ == "__main__":
192    i = Invertor()
193    x, y, erro = load("test/Cyl_A_D102.txt")
194    i.d_max = 102.0
195    i.nfunc = 10
196    #i.q_max = 0.4
197    #i.q_min = 0.07
198    i.x = x
199    i.y = y
200    i.err = erro
201    #i.out, i.cov = i.lstsq(10)
202    # Testing estimator
203    est = Num_terms(i)
204    print est.num_terms()
Note: See TracBrowser for help on using the repository browser.