source: sasview/sansmodels/test/SmearList.py @ b4293d2

ESS_GUIESS_GUI_DocsESS_GUI_batch_fittingESS_GUI_bumps_abstractionESS_GUI_iss1116ESS_GUI_iss879ESS_GUI_iss959ESS_GUI_openclESS_GUI_orderingESS_GUI_sync_sascalccostrafo411magnetic_scattrelease-4.1.1release-4.1.2release-4.2.2release_4.0.1ticket-1009ticket-1094-headlessticket-1242-2d-resolutionticket-1243ticket-1249ticket885unittest-saveload
Last change on this file since b4293d2 was 0abf7bf, checked in by Mathieu Doucet <doucetm@…>, 13 years ago

Re #5 Get rid of pyre models, odb creation, and modelfactory, all of which haven't been used since 2007.

  • Property mode set to 100644
File size: 5.1 KB
Line 
1
2from sans.models.BaseComponent import BaseComponent
3import math
4
5class Smear:
6   
7    def __init__(self, model, param, sigma):
8        """
9            @param model: model to smear
10            @param param: parameter to smear
11            @param sigma: std deviations for parameter
12        """
13
14       
15        ## Model to smear
16        self.model = model
17        ## Standard deviation of the smearing
18        self.sigmas = sigma
19        ## Parameter to smear
20        self.params = param
21        ## Nominal value of the smeared parameter
22        self.centers = []
23        ## Error on last evaluation
24        self.error = 0.0
25        ## Error flag
26        self.doErrors = False
27        for par in self.params:
28            self.centers.append(self.model.getParam(par))
29       
30    def smearParam(self, id, x):
31        """
32            @param x: input value
33        """
34        # from random import random
35        # If we exhausted the parameter array, simply evaluate
36        # the model
37        if id < len(self.params):
38            #print "smearing", self.params[id]
39           
40            # Average over Gaussian distribution (2 sigmas)
41            value_sum = 0.0
42            gauss_sum = 0.0
43           
44            min_value = self.centers[id] - 2*self.sigmas[id]
45            max_value = self.centers[id] + 2*self.sigmas[id]
46            n_pts = 25
47            step = (max_value - min_value)/(n_pts-1)
48            #print min_value, max_value, step, max_value-min_value
49            if step == 0.0:
50                return self.smearParam(id+1,x)
51           
52            # Gaussian function used to weigh points
53            gaussian = Gaussian()
54            gaussian.setParam('sigma', self.sigmas[id])
55            gaussian.setParam('mean', self.centers[id])
56                   
57            # Compute average
58            prev_value = None
59            error_sys = 0.0
60            for i in range(n_pts):
61                # Set the parameter value           
62                value = min_value + i*step
63                # value = random()*4.0*self.sigmas[id] + min_value
64                # print value
65                self.model.setParam(self.params[id], value)
66                gauss_value = gaussian.run(value)
67                #gauss_value = 1.0
68                #if id==0: print value, gauss_value
69                func_value, error_1 = self.smearParam(id+1, x)
70                if self.doErrors:
71                    if not prev_value == None:
72                        error_sys += (func_value-prev_value)*(func_value-prev_value)/4.0
73                    prev_value = func_value
74
75                value_sum += gauss_value * func_value
76                gauss_sum += gauss_value
77               
78            #print "Error", math.sqrt(error)
79            return value_sum/gauss_sum, math.sqrt(error_sys)
80       
81        else:
82            return self.model.run(x), 0.0
83       
84    def run(self, x):
85        """
86            @param x: input
87        """
88       
89        # Go through the list of parameters
90        n_par = len(self.params)
91       
92        # Check array lengths
93        if not len(self.centers) == n_par or\
94            not len(self.sigmas) == n_par:
95            raise ValueError, "Incompatible array lengths"
96       
97        # Smear first parameters
98        if n_par > 0:
99            value, error = self.smearParam(0, x)
100            self.error = error
101           
102            # Put back original values
103            for i in range(len(self.centers)):
104                self.model.setParam(self.params[i], self.centers[i])
105           
106           
107            return value
108       
109class Gaussian(BaseComponent):
110    """ Gaussian function """
111   
112    def __init__(self):
113        """ Initialization """
114        BaseComponent.__init__(self)
115       
116        ## Name of the model
117        self.name = "Gaussian"
118        ## Scale factor
119        self.params['scale']  = 1.0
120        ## Mean value
121        self.params['mean']   = 0.0
122        ## Standard deviation
123        self.params['sigma']  = 1.0
124        ## Internal log
125        self.log = {}
126        return
127   
128    def run(self, x=0):
129        """ Evaluate the function
130            @param x: input q or [q,phi]
131            @return: scattering function
132        """
133        if(type(x)==type([])):
134            # vector input
135            if(len(x)==2):
136                return self.analytical_2D(x)
137            else:
138                raise ValueError, "Gaussian takes a scalar or a 2D point"
139        else:
140            return self.analytical_1D(x)
141
142    def analytical_2D(self, x):
143        """ Evaluate 2D model
144            @param x: input [q,phi]
145            @return: scattering function
146        """
147       
148        # 2D sphere is the same as 1D sphere
149        return self.analytical_1D(x[0])
150
151    def analytical_1D(self, x):
152        """ Evaluate 1D model
153            @param x: input q
154            @return: scattering function
155        """
156        vary = x-self.params['mean']
157        expo_value = -vary*vary/(2*self.params['sigma']*self.params['sigma'])
158        value = self.params['scale'] *  math.exp(expo_value)
159       
160        return value
Note: See TracBrowser for help on using the repository browser.