source: sasview/pr_inversion/num_term.py @ 8b6f489

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 8b6f489 was 1db4a53, checked in by Gervaise Alina <gervyh@…>, 14 years ago

working on pylint

  • Property mode set to 100644
File size: 6.1 KB
RevLine 
[1db4a53]1
2#import unittest
3import math
4import numpy
5#import sys
6#import string
7import copy
[e96a852]8from sans.pr.invertor import Invertor
9
10class Num_terms():   
[d84a90c]11    """
12    """
13    def __init__(self, invertor):
14        """
15        """
16        self.invertor = invertor
17        self.nterm_min = 10
18        self.nterm_max = len(self.invertor.x)
[1db4a53]19        if self.nterm_max > 50:
20            self.nterm_max = 50
[d84a90c]21        self.isquit_func = None
[e96a852]22         
[d84a90c]23        self.osc_list = []
24        self.err_list = []
25        self.alpha_list = []
26        self.mess_list = []
[e96a852]27         
[d84a90c]28        self.dataset = []
[e96a852]29     
[d84a90c]30    def is_odd(self, n):
31        """
32        """
[1db4a53]33        return bool(n % 2)
[e96a852]34
[d84a90c]35    def sort_osc(self):
36        """
37        """
[1db4a53]38        #import copy
[d84a90c]39        osc = copy.deepcopy(self.dataset)
40        lis = []
41        for i in range(len(osc)):
42            osc.sort()
43            re = osc.pop(0)
44            lis.append(re)
45        return lis
[e96a852]46           
[d84a90c]47    def median_osc(self):
48        """
49        """
50        osc = self.sort_osc()
51        dv = len(osc)
52        med = float(dv) / 2.0
53        odd = self.is_odd(dv)
54        medi = 0
55        for i in range(dv):
56            if odd == True:
57                medi = osc[int(med)]
58            else:
59                medi = osc[int(med) - 1]
60        return medi
[e96a852]61
[d84a90c]62    def get0_out(self):
63        """
64        """ 
[e96a852]65        inver = self.invertor
66        self.osc_list = []
67        self.err_list = []
68        self.alpha_list = []
[7578961]69        for k in range(self.nterm_min, self.nterm_max, 1):
[e96a852]70            if self.isquit_func != None:
71                self.isquit_func()
72            best_alpha, message, elapsed = inver.estimate_alpha(k)
73            inver.alpha = best_alpha
74            inver.out, inver.cov = inver.lstsq(k)
75            osc = inver.oscillations(inver.out)
76            err = inver.get_pos_err(inver.out, inver.cov)
[1db4a53]77            if osc > 10.0:
[7578961]78                break
[e96a852]79            self.osc_list.append(osc)
80            self.err_list.append(err)
81            self.alpha_list.append(inver.alpha)
82            self.mess_list.append(message)
83         
84        #print "osc", self.osc_list
85        #print "err", self.err_list
86        #print "alpha", self.alpha_list
[1db4a53]87        #new_ls = []
[e96a852]88        new_osc1 = []
[1db4a53]89        new_osc2 = []
[e96a852]90        new_osc3 = []
[1db4a53]91        flag9 = False
92        flag8 = False
93        flag7 = False
[e96a852]94        for i in range(len(self.err_list)):
[1db4a53]95            if self.err_list[i] <= 1.0 and self.err_list[i] >= 0.9:
[e96a852]96                new_osc1.append(self.osc_list[i])
[1db4a53]97                flag9 = True
98            if self.err_list[i] < 0.9 and self.err_list[i] >= 0.8:
[e96a852]99                new_osc2.append(self.osc_list[i])
[1db4a53]100                flag8 = True
101            if self.err_list[i] < 0.8 and self.err_list[i] >= 0.7:
[e96a852]102                new_osc3.append(self.osc_list[i])
[1db4a53]103                falg7 = True
[e96a852]104                 
[1db4a53]105        if flag9 == True:
[e96a852]106            self.dataset = new_osc1
[1db4a53]107        elif flag8 == True:
[e96a852]108            self.dataset = new_osc2
109        else:
110            self.dataset = new_osc3
111         
112        #print "dataset", self.dataset
113        return self.dataset
114       
[d84a90c]115    def ls_osc(self):
116        """
117        """
[e96a852]118        # Generate data
119        ls_osc = self.get0_out()
120        med = self.median_osc()
121       
122        #TODO: check 1
123        ls_osc = self.dataset
124        ls = []
125        for i in range(len(ls_osc)):
126            if int(med) == int(ls_osc[i]):
127                ls.append(ls_osc[i])
128        return ls
129
[d84a90c]130    def compare_err(self):
131        """
132        """
[e96a852]133        ls = self.ls_osc()
134        #print "ls", ls
135        nt_ls = []
136        for i in range(len(ls)):
137            r = ls[i]
138            n = self.osc_list.index(r) + 10
139            #er = self.err_list[n]
140            #nt = self.osc_list.index(r) + 10
141            nt_ls.append(n)
[bd30f4a5]142        #print "nt list", nt_ls
[e96a852]143        return nt_ls
144
[d84a90c]145    def num_terms(self, isquit_func=None):
146        """
147        """
148        try:
149            self.isquit_func = isquit_func
150            #self.nterm_max = len(self.invertor.x)
151            #self.nterm_max = 32
152            nts = self.compare_err()
153            #print "nts", nts
154            div = len(nts)
155            tem = float(div)/2.0
156            odd = self.is_odd(div)
157            if odd == True:
158                nt = nts[int(tem)]
159            else:
160                nt = nts[int(tem) - 1]
161            return nt, self.alpha_list[nt - 10], self.mess_list[nt-10]
162        except:
163            return self.nterm_min, self.alpha_list[10], self.mess_list[10]
[e96a852]164
165#For testing
166def load(path):
[1db4a53]167    #import numpy, math, sys
[d84a90c]168    # Read the data from the data file
169    data_x   = numpy.zeros(0)
170    data_y   = numpy.zeros(0)
171    data_err = numpy.zeros(0)
172    scale    = None
173    min_err  = 0.0
174    if not path == None:
175        input_f = open(path,'r')
176        buff    = input_f.read()
177        lines   = buff.split('\n')
178        for line in lines:
179            try:
180                toks = line.split()
[1db4a53]181                test_x = float(toks[0])
182                test_y = float(toks[1])
183                if len(toks) > 2:
[d84a90c]184                    err = float(toks[2])
185                else:
[1db4a53]186                    if scale == None:
187                        scale = 0.05 * math.sqrt(test_y)
[d84a90c]188                        #scale = 0.05/math.sqrt(y)
189                        min_err = 0.01*y
[1db4a53]190                    err = scale * math.sqrt(test_y) + min_err
[d84a90c]191                    #err = 0
192                   
[1db4a53]193                data_x = numpy.append(data_x, test_x)
194                data_y = numpy.append(data_y, test_y)
[d84a90c]195                data_err = numpy.append(data_err, err)
196            except:
197                pass
198               
199    return data_x, data_y, data_err
200
[e96a852]201
202if __name__ == "__main__":
203    i = Invertor()
204    x, y, erro = load("test/Cyl_A_D102.txt")
205    i.d_max = 102.0
206    i.nfunc = 10   
207    #i.q_max = 0.4
208    #i.q_min = 0.07
209    i.x   = x
210    i.y   = y
211    i.err = erro   
212    #i.out, i.cov = i.lstsq(10)
213    # Testing estimator
214    est = Num_terms(i)
215    print est.num_terms()
216   
217       
218       
219       
220       
221       
Note: See TracBrowser for help on using the repository browser.