source: sasview/src/sas/sascalc/pr/num_term.py @ 489bb46

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since 489bb46 was b699768, checked in by Piotr Rozyczko <piotr.rozyczko@…>, 9 years ago

Initial commit of the refactored SasCalc? module.

  • Property mode set to 100644
File size: 5.5 KB
RevLine 
[1db4a53]1import math
2import numpy
3import copy
[5f8fc78]4import sys
5import logging
[b699768]6from sas.sascalc.pr.invertor import Invertor
[e96a852]7
[5f8fc78]8class NTermEstimator(object):
[d84a90c]9    """
10    """
11    def __init__(self, invertor):
12        """
13        """
14        self.invertor = invertor
15        self.nterm_min = 10
16        self.nterm_max = len(self.invertor.x)
[1db4a53]17        if self.nterm_max > 50:
18            self.nterm_max = 50
[d84a90c]19        self.isquit_func = None
[038c00cf]20
[d84a90c]21        self.osc_list = []
22        self.err_list = []
23        self.alpha_list = []
24        self.mess_list = []
25        self.dataset = []
[038c00cf]26
[d84a90c]27    def is_odd(self, n):
28        """
29        """
[1db4a53]30        return bool(n % 2)
[e96a852]31
[d84a90c]32    def sort_osc(self):
33        """
34        """
[1db4a53]35        #import copy
[d84a90c]36        osc = copy.deepcopy(self.dataset)
37        lis = []
38        for i in range(len(osc)):
39            osc.sort()
40            re = osc.pop(0)
41            lis.append(re)
42        return lis
[038c00cf]43
[d84a90c]44    def median_osc(self):
45        """
46        """
47        osc = self.sort_osc()
48        dv = len(osc)
49        med = float(dv) / 2.0
50        odd = self.is_odd(dv)
51        medi = 0
52        for i in range(dv):
53            if odd == True:
54                medi = osc[int(med)]
55            else:
56                medi = osc[int(med) - 1]
57        return medi
[e96a852]58
[d84a90c]59    def get0_out(self):
60        """
[34f3ad0]61        """
[e96a852]62        inver = self.invertor
63        self.osc_list = []
64        self.err_list = []
65        self.alpha_list = []
[7578961]66        for k in range(self.nterm_min, self.nterm_max, 1):
[e96a852]67            if self.isquit_func != None:
68                self.isquit_func()
[34f3ad0]69            best_alpha, message, _ = inver.estimate_alpha(k)
[e96a852]70            inver.alpha = best_alpha
71            inver.out, inver.cov = inver.lstsq(k)
72            osc = inver.oscillations(inver.out)
73            err = inver.get_pos_err(inver.out, inver.cov)
[1db4a53]74            if osc > 10.0:
[7578961]75                break
[e96a852]76            self.osc_list.append(osc)
77            self.err_list.append(err)
78            self.alpha_list.append(inver.alpha)
79            self.mess_list.append(message)
[038c00cf]80
[e96a852]81        new_osc1 = []
[1db4a53]82        new_osc2 = []
[e96a852]83        new_osc3 = []
[1db4a53]84        flag9 = False
85        flag8 = False
[e96a852]86        for i in range(len(self.err_list)):
[1db4a53]87            if self.err_list[i] <= 1.0 and self.err_list[i] >= 0.9:
[e96a852]88                new_osc1.append(self.osc_list[i])
[1db4a53]89                flag9 = True
90            if self.err_list[i] < 0.9 and self.err_list[i] >= 0.8:
[e96a852]91                new_osc2.append(self.osc_list[i])
[1db4a53]92                flag8 = True
93            if self.err_list[i] < 0.8 and self.err_list[i] >= 0.7:
[e96a852]94                new_osc3.append(self.osc_list[i])
[038c00cf]95
[1db4a53]96        if flag9 == True:
[e96a852]97            self.dataset = new_osc1
[1db4a53]98        elif flag8 == True:
[e96a852]99            self.dataset = new_osc2
100        else:
101            self.dataset = new_osc3
[038c00cf]102
[e96a852]103        return self.dataset
[038c00cf]104
[d84a90c]105    def ls_osc(self):
106        """
107        """
[e96a852]108        # Generate data
[bf6b8d1]109        self.get0_out()
[e96a852]110        med = self.median_osc()
[038c00cf]111
[e96a852]112        #TODO: check 1
113        ls_osc = self.dataset
114        ls = []
115        for i in range(len(ls_osc)):
116            if int(med) == int(ls_osc[i]):
117                ls.append(ls_osc[i])
118        return ls
119
[d84a90c]120    def compare_err(self):
121        """
122        """
[e96a852]123        ls = self.ls_osc()
124        nt_ls = []
125        for i in range(len(ls)):
126            r = ls[i]
127            n = self.osc_list.index(r) + 10
128            nt_ls.append(n)
129        return nt_ls
130
[d84a90c]131    def num_terms(self, isquit_func=None):
132        """
133        """
134        try:
135            self.isquit_func = isquit_func
136            nts = self.compare_err()
137            div = len(nts)
[038c00cf]138            tem = float(div) / 2.0
[d84a90c]139            odd = self.is_odd(div)
140            if odd == True:
141                nt = nts[int(tem)]
142            else:
143                nt = nts[int(tem) - 1]
[038c00cf]144            return nt, self.alpha_list[nt - 10], self.mess_list[nt - 10]
[d84a90c]145        except:
[c1bffa5]146            #TODO: check the logic above and make sure it doesn't
147            # rely on the try-except.
[bf6b8d1]148            return self.nterm_min, self.invertor.alpha, ''
[e96a852]149
[34f3ad0]150
[e96a852]151#For testing
152def load(path):
[d84a90c]153    # Read the data from the data file
[038c00cf]154    data_x = numpy.zeros(0)
155    data_y = numpy.zeros(0)
[d84a90c]156    data_err = numpy.zeros(0)
[038c00cf]157    scale = None
158    min_err = 0.0
[d84a90c]159    if not path == None:
[038c00cf]160        input_f = open(path, 'r')
161        buff = input_f.read()
162        lines = buff.split('\n')
[d84a90c]163        for line in lines:
164            try:
165                toks = line.split()
[1db4a53]166                test_x = float(toks[0])
167                test_y = float(toks[1])
168                if len(toks) > 2:
[d84a90c]169                    err = float(toks[2])
170                else:
[1db4a53]171                    if scale == None:
172                        scale = 0.05 * math.sqrt(test_y)
[d84a90c]173                        #scale = 0.05/math.sqrt(y)
[34f3ad0]174                        min_err = 0.01 * y
[1db4a53]175                    err = scale * math.sqrt(test_y) + min_err
[d84a90c]176                    #err = 0
[038c00cf]177
[1db4a53]178                data_x = numpy.append(data_x, test_x)
179                data_y = numpy.append(data_y, test_y)
[d84a90c]180                data_err = numpy.append(data_err, err)
181            except:
[5f8fc78]182                logging.error(sys.exc_value)
[038c00cf]183
[d84a90c]184    return data_x, data_y, data_err
185
[e96a852]186
187if __name__ == "__main__":
[5f8fc78]188    invert = Invertor()
[e96a852]189    x, y, erro = load("test/Cyl_A_D102.txt")
[5f8fc78]190    invert.d_max = 102.0
191    invert.nfunc = 10
192    invert.x = x
193    invert.y = y
194    invert.err = erro
[e96a852]195    # Testing estimator
[5f8fc78]196    est = NTermEstimator(invert)
[e96a852]197    print est.num_terms()
Note: See TracBrowser for help on using the repository browser.