source: sasview/src/sas/sascalc/pr/num_term.py @ a1b8fee

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.2.2ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since a1b8fee was a1b8fee, checked in by andyfaff, 7 years ago

MAINT: from future import print_function

  • Property mode set to 100644
File size: 5.3 KB
Line 
1from __future__ import print_function
2
3import math
4import numpy as np
5import copy
6import sys
7import logging
8from sas.sascalc.pr.invertor import Invertor
9
10logger = logging.getLogger(__name__)
11
12class NTermEstimator(object):
13    """
14    """
15    def __init__(self, invertor):
16        """
17        """
18        self.invertor = invertor
19        self.nterm_min = 10
20        self.nterm_max = len(self.invertor.x)
21        if self.nterm_max > 50:
22            self.nterm_max = 50
23        self.isquit_func = None
24
25        self.osc_list = []
26        self.err_list = []
27        self.alpha_list = []
28        self.mess_list = []
29        self.dataset = []
30
31    def is_odd(self, n):
32        """
33        """
34        return bool(n % 2)
35
36    def sort_osc(self):
37        """
38        """
39        #import copy
40        osc = copy.deepcopy(self.dataset)
41        lis = []
42        for i in range(len(osc)):
43            osc.sort()
44            re = osc.pop(0)
45            lis.append(re)
46        return lis
47
48    def median_osc(self):
49        """
50        """
51        osc = self.sort_osc()
52        dv = len(osc)
53        med = float(dv) / 2.0
54        odd = self.is_odd(dv)
55        medi = 0
56        for i in range(dv):
57            if odd == True:
58                medi = osc[int(med)]
59            else:
60                medi = osc[int(med) - 1]
61        return medi
62
63    def get0_out(self):
64        """
65        """
66        inver = self.invertor
67        self.osc_list = []
68        self.err_list = []
69        self.alpha_list = []
70        for k in range(self.nterm_min, self.nterm_max, 1):
71            if self.isquit_func is not None:
72                self.isquit_func()
73            best_alpha, message, _ = inver.estimate_alpha(k)
74            inver.alpha = best_alpha
75            inver.out, inver.cov = inver.lstsq(k)
76            osc = inver.oscillations(inver.out)
77            err = inver.get_pos_err(inver.out, inver.cov)
78            if osc > 10.0:
79                break
80            self.osc_list.append(osc)
81            self.err_list.append(err)
82            self.alpha_list.append(inver.alpha)
83            self.mess_list.append(message)
84
85        new_osc1 = []
86        new_osc2 = []
87        new_osc3 = []
88        flag9 = False
89        flag8 = False
90        for i in range(len(self.err_list)):
91            if self.err_list[i] <= 1.0 and self.err_list[i] >= 0.9:
92                new_osc1.append(self.osc_list[i])
93                flag9 = True
94            if self.err_list[i] < 0.9 and self.err_list[i] >= 0.8:
95                new_osc2.append(self.osc_list[i])
96                flag8 = True
97            if self.err_list[i] < 0.8 and self.err_list[i] >= 0.7:
98                new_osc3.append(self.osc_list[i])
99
100        if flag9 == True:
101            self.dataset = new_osc1
102        elif flag8 == True:
103            self.dataset = new_osc2
104        else:
105            self.dataset = new_osc3
106
107        return self.dataset
108
109    def ls_osc(self):
110        """
111        """
112        # Generate data
113        self.get0_out()
114        med = self.median_osc()
115
116        #TODO: check 1
117        ls_osc = self.dataset
118        ls = []
119        for i in range(len(ls_osc)):
120            if int(med) == int(ls_osc[i]):
121                ls.append(ls_osc[i])
122        return ls
123
124    def compare_err(self):
125        """
126        """
127        ls = self.ls_osc()
128        nt_ls = []
129        for i in range(len(ls)):
130            r = ls[i]
131            n = self.osc_list.index(r) + 10
132            nt_ls.append(n)
133        return nt_ls
134
135    def num_terms(self, isquit_func=None):
136        """
137        """
138        try:
139            self.isquit_func = isquit_func
140            nts = self.compare_err()
141            div = len(nts)
142            tem = float(div) / 2.0
143            odd = self.is_odd(div)
144            if odd == True:
145                nt = nts[int(tem)]
146            else:
147                nt = nts[int(tem) - 1]
148            return nt, self.alpha_list[nt - 10], self.mess_list[nt - 10]
149        except:
150            #TODO: check the logic above and make sure it doesn't
151            # rely on the try-except.
152            return self.nterm_min, self.invertor.alpha, ''
153
154
155#For testing
156def load(path):
157    # Read the data from the data file
158    data_x = np.zeros(0)
159    data_y = np.zeros(0)
160    data_err = np.zeros(0)
161    scale = None
162    min_err = 0.0
163    if path is not None:
164        input_f = open(path, 'r')
165        buff = input_f.read()
166        lines = buff.split('\n')
167        for line in lines:
168            try:
169                toks = line.split()
170                test_x = float(toks[0])
171                test_y = float(toks[1])
172                if len(toks) > 2:
173                    err = float(toks[2])
174                else:
175                    if scale is None:
176                        scale = 0.05 * math.sqrt(test_y)
177                        #scale = 0.05/math.sqrt(y)
178                        min_err = 0.01 * y
179                    err = scale * math.sqrt(test_y) + min_err
180                    #err = 0
181
182                data_x = np.append(data_x, test_x)
183                data_y = np.append(data_y, test_y)
184                data_err = np.append(data_err, err)
185            except:
186                logger.error(sys.exc_value)
187
188    return data_x, data_y, data_err
189
190
191if __name__ == "__main__":
192    invert = Invertor()
193    x, y, erro = load("test/Cyl_A_D102.txt")
194    invert.d_max = 102.0
195    invert.nfunc = 10
196    invert.x = x
197    invert.y = y
198    invert.err = erro
199    # Testing estimator
200    est = NTermEstimator(invert)
201    print(est.num_terms())
Note: See TracBrowser for help on using the repository browser.