source: sasview/src/sas/sascalc/pr/num_term.py @ 87d44c7

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.2.2ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 87d44c7 was 9c0f3c17, checked in by Ricardo Ferraz Leal <ricleal@…>, 8 years ago

After merge conflict

  • Property mode set to 100644
File size: 5.5 KB
RevLine 
[1db4a53]1import math
[9a5097c]2import numpy as np
[1db4a53]3import copy
[5f8fc78]4import sys
5import logging
[b699768]6from sas.sascalc.pr.invertor import Invertor
[e96a852]7
[463e7ffc]8logger = logging.getLogger(__name__)
[c155a16]9
[5f8fc78]10class NTermEstimator(object):
[d84a90c]11    """
12    """
13    def __init__(self, invertor):
14        """
15        """
16        self.invertor = invertor
17        self.nterm_min = 10
18        self.nterm_max = len(self.invertor.x)
[1db4a53]19        if self.nterm_max > 50:
20            self.nterm_max = 50
[d84a90c]21        self.isquit_func = None
[038c00cf]22
[d84a90c]23        self.osc_list = []
24        self.err_list = []
25        self.alpha_list = []
26        self.mess_list = []
27        self.dataset = []
[038c00cf]28
[d84a90c]29    def is_odd(self, n):
30        """
31        """
[1db4a53]32        return bool(n % 2)
[e96a852]33
[d84a90c]34    def sort_osc(self):
35        """
36        """
[1db4a53]37        #import copy
[d84a90c]38        osc = copy.deepcopy(self.dataset)
39        lis = []
40        for i in range(len(osc)):
41            osc.sort()
42            re = osc.pop(0)
43            lis.append(re)
44        return lis
[038c00cf]45
[d84a90c]46    def median_osc(self):
47        """
48        """
49        osc = self.sort_osc()
50        dv = len(osc)
51        med = float(dv) / 2.0
52        odd = self.is_odd(dv)
53        medi = 0
54        for i in range(dv):
55            if odd == True:
56                medi = osc[int(med)]
57            else:
58                medi = osc[int(med) - 1]
59        return medi
[e96a852]60
[d84a90c]61    def get0_out(self):
62        """
[34f3ad0]63        """
[e96a852]64        inver = self.invertor
65        self.osc_list = []
66        self.err_list = []
67        self.alpha_list = []
[7578961]68        for k in range(self.nterm_min, self.nterm_max, 1):
[e96a852]69            if self.isquit_func != None:
70                self.isquit_func()
[34f3ad0]71            best_alpha, message, _ = inver.estimate_alpha(k)
[e96a852]72            inver.alpha = best_alpha
73            inver.out, inver.cov = inver.lstsq(k)
74            osc = inver.oscillations(inver.out)
75            err = inver.get_pos_err(inver.out, inver.cov)
[1db4a53]76            if osc > 10.0:
[7578961]77                break
[e96a852]78            self.osc_list.append(osc)
79            self.err_list.append(err)
80            self.alpha_list.append(inver.alpha)
81            self.mess_list.append(message)
[038c00cf]82
[e96a852]83        new_osc1 = []
[1db4a53]84        new_osc2 = []
[e96a852]85        new_osc3 = []
[1db4a53]86        flag9 = False
87        flag8 = False
[e96a852]88        for i in range(len(self.err_list)):
[1db4a53]89            if self.err_list[i] <= 1.0 and self.err_list[i] >= 0.9:
[e96a852]90                new_osc1.append(self.osc_list[i])
[1db4a53]91                flag9 = True
92            if self.err_list[i] < 0.9 and self.err_list[i] >= 0.8:
[e96a852]93                new_osc2.append(self.osc_list[i])
[1db4a53]94                flag8 = True
95            if self.err_list[i] < 0.8 and self.err_list[i] >= 0.7:
[e96a852]96                new_osc3.append(self.osc_list[i])
[038c00cf]97
[1db4a53]98        if flag9 == True:
[e96a852]99            self.dataset = new_osc1
[1db4a53]100        elif flag8 == True:
[e96a852]101            self.dataset = new_osc2
102        else:
103            self.dataset = new_osc3
[038c00cf]104
[e96a852]105        return self.dataset
[038c00cf]106
[d84a90c]107    def ls_osc(self):
108        """
109        """
[e96a852]110        # Generate data
[bf6b8d1]111        self.get0_out()
[e96a852]112        med = self.median_osc()
[038c00cf]113
[e96a852]114        #TODO: check 1
115        ls_osc = self.dataset
116        ls = []
117        for i in range(len(ls_osc)):
118            if int(med) == int(ls_osc[i]):
119                ls.append(ls_osc[i])
120        return ls
121
[d84a90c]122    def compare_err(self):
123        """
124        """
[e96a852]125        ls = self.ls_osc()
126        nt_ls = []
127        for i in range(len(ls)):
128            r = ls[i]
129            n = self.osc_list.index(r) + 10
130            nt_ls.append(n)
131        return nt_ls
132
[d84a90c]133    def num_terms(self, isquit_func=None):
134        """
135        """
136        try:
137            self.isquit_func = isquit_func
138            nts = self.compare_err()
139            div = len(nts)
[038c00cf]140            tem = float(div) / 2.0
[d84a90c]141            odd = self.is_odd(div)
142            if odd == True:
143                nt = nts[int(tem)]
144            else:
145                nt = nts[int(tem) - 1]
[038c00cf]146            return nt, self.alpha_list[nt - 10], self.mess_list[nt - 10]
[d84a90c]147        except:
[c1bffa5]148            #TODO: check the logic above and make sure it doesn't
149            # rely on the try-except.
[bf6b8d1]150            return self.nterm_min, self.invertor.alpha, ''
[e96a852]151
[34f3ad0]152
[e96a852]153#For testing
154def load(path):
[d84a90c]155    # Read the data from the data file
[9a5097c]156    data_x = np.zeros(0)
157    data_y = np.zeros(0)
158    data_err = np.zeros(0)
[038c00cf]159    scale = None
160    min_err = 0.0
[d84a90c]161    if not path == None:
[038c00cf]162        input_f = open(path, 'r')
163        buff = input_f.read()
164        lines = buff.split('\n')
[d84a90c]165        for line in lines:
166            try:
167                toks = line.split()
[1db4a53]168                test_x = float(toks[0])
169                test_y = float(toks[1])
170                if len(toks) > 2:
[d84a90c]171                    err = float(toks[2])
172                else:
[1db4a53]173                    if scale == None:
174                        scale = 0.05 * math.sqrt(test_y)
[d84a90c]175                        #scale = 0.05/math.sqrt(y)
[34f3ad0]176                        min_err = 0.01 * y
[1db4a53]177                    err = scale * math.sqrt(test_y) + min_err
[d84a90c]178                    #err = 0
[038c00cf]179
[9a5097c]180                data_x = np.append(data_x, test_x)
181                data_y = np.append(data_y, test_y)
182                data_err = np.append(data_err, err)
[d84a90c]183            except:
[c155a16]184                logger.error(sys.exc_value)
[038c00cf]185
[d84a90c]186    return data_x, data_y, data_err
187
[e96a852]188
189if __name__ == "__main__":
[5f8fc78]190    invert = Invertor()
[e96a852]191    x, y, erro = load("test/Cyl_A_D102.txt")
[5f8fc78]192    invert.d_max = 102.0
193    invert.nfunc = 10
194    invert.x = x
195    invert.y = y
196    invert.err = erro
[e96a852]197    # Testing estimator
[5f8fc78]198    est = NTermEstimator(invert)
[e96a852]199    print est.num_terms()
Note: See TracBrowser for help on using the repository browser.