source: sasview/src/sas/sascalc/calculator/BaseComponent.py @ b9d74f3

Last change on this file since b9d74f3 was b9d74f3, checked in by andyfaff, 8 years ago

MAINT: use raise Exception() not raise Exception

  • Property mode set to 100644
File size: 9.0 KB
Line 
1#!/usr/bin/env python
2
3"""
4Provide base functionality for all model components
5"""
6
7# imports
8import copy
9from collections import OrderedDict
10
11import numpy as np
# TODO: find a way to make each parameter report by itself whether it is
# fittable or not.
14
class BaseComponent:
    """
    Basic model component.

    Holds the parameter, dispersion and metadata dictionaries shared by
    model components, and provides generic parameter access and q-value
    evaluation on top of the ``run``/``runXY`` hooks that subclasses
    override.

    Since version 0.5.0, basic operations are no longer supported.
    """

    def __init__(self):
        """ Initialization"""

        ## Name of the model
        self.name = "BaseComponent"

        ## Parameters to be accessed by client
        self.params = {}
        # Per-parameter details; presumably units/limits -- TODO confirm
        self.details = {}
        ## Dictionary used to store the dispersity/averaging
        #  parameters of dispersed/averaged parameters.
        self.dispersion = {}
        # string containing information about the model such as the equation
        # of the given model, exception or possible use
        self.description = ''
        # list of parameters that can be fitted (consulted by is_fittable)
        self.fixed = []
        # list of non-fittable parameters
        self.non_fittable = []
        ## parameters with orientation
        self.orientation_params = []
        ## magnetic parameters
        self.magnetic_params = []
        ## store dispersity reference
        self._persistency_dict = {}
        ## independent parameter name and unit [string]
        self.input_name = "Q"
        self.input_unit = "A^{-1}"
        ## output name and unit  [string]
        self.output_name = "Intensity"
        self.output_unit = "cm^{-1}"

        self.is_multiplicity_model = False
        self.is_structure_factor = False
        self.is_form_factor = False

    def __str__(self):
        """
        :return: string representation of the model (its name)
        """
        return self.name

    def is_fittable(self, par_name):
        """
        Check if a given parameter is fittable or not.

        :param par_name: the parameter name to check; it is lower-cased
            before the lookup, so entries in ``self.fixed`` must be
            lower-case to match
        :return: True if the lower-cased name is listed in ``self.fixed``
        """
        # TODO: eventually ask the parameter itself whether it is fittable.
        return par_name.lower() in self.fixed

    def run(self, x):
        """
        run 1d

        Placeholder for subclasses; the base class returns the
        ``NotImplemented`` sentinel rather than raising.
        """
        return NotImplemented

    def runXY(self, x):
        """
        run 2d

        Placeholder for subclasses; the base class returns the
        ``NotImplemented`` sentinel rather than raising.
        """
        return NotImplemented

    def calculate_ER(self):
        """
        Calculate effective radius

        Placeholder; the base class returns ``NotImplemented``.
        """
        return NotImplemented

    def calculate_VR(self):
        """
        Calculate volume fraction ratio

        Placeholder; the base class returns ``NotImplemented``.
        """
        return NotImplemented

    def evalDistribution(self, qdist):
        """
        Evaluate a distribution of q-values.

        * For 1D, a numpy array is expected as input: ::

            evalDistribution(q)

          where q is a numpy array.


        * For 2D, a list of numpy arrays are expected: [qx_prime,qy_prime],
          where 1D arrays, ::

              qx_prime = [ qx[0], qx[1], qx[2], ....]

          and ::

              qy_prime = [ qy[0], qy[1], qy[2], ....]

        Then get ::

            q = np.sqrt(qx_prime^2+qy_prime^2)

        that is a qr in 1D array; ::

            q = [q[0], q[1], q[2], ....]

        .. note:: Due to 2D speed issue, no anisotropic scattering
                  is supported for python models, thus C-models should have
                  their own evalDistribution methods.

        The method is then called the following way: ::

            evalDistribution(q)

        where q is a numpy array.

        :param qdist: ndarray of scalar q-values or list [qx,qy] where qx,qy
            are 1D ndarrays (ndarray subclasses are accepted as well)
        :return: float ndarray with one intensity per q-value
        :raises RuntimeError: if qdist is neither an ndarray nor a list of
            exactly two ndarrays
        """
        if isinstance(qdist, list):
            # Check whether we have a list of ndarrays [qx,qy]
            if len(qdist) != 2 or \
                    not all(isinstance(qi, np.ndarray) for qi in qdist):
                msg = "evalDistribution expects a list of 2 ndarrays"
                raise RuntimeError(msg)
            qx, qy = qdist
            # calculate q_r component for 2D isotropic
            q = np.sqrt(qx**2 + qy**2)
        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            q = qdist
        else:
            mesg = "evalDistribution is expecting an ndarray of scalar q-values"
            mesg += " or a list [qx,qy] where qx,qy are 1D ndarrays."
            raise RuntimeError(mesg)

        # NOTE(review): both branches evaluate runXY (not run) on each scalar
        # q, exactly as before -- confirm this is intended for the 1D case.
        # otypes=[float] fixes the output dtype and lets empty input work.
        v_model = np.vectorize(self.runXY, otypes=[float])
        return v_model(q)

    def clone(self):
        """ Returns a new object identical to the current object """
        obj = copy.deepcopy(self)
        return self._clone(obj)

    def _clone(self, obj):
        """
        Internal utility function to copy the internal
        data members to a fresh copy.

        :param obj: the object receiving independent deep copies of
            params, details, dispersion and _persistency_dict
        :return: obj
        """
        obj.params = copy.deepcopy(self.params)
        obj.details = copy.deepcopy(self.details)
        obj.dispersion = copy.deepcopy(self.dispersion)
        obj._persistency_dict = copy.deepcopy(self._persistency_dict)
        return obj

    def set_dispersion(self, parameter, dispersion):
        """
        model dispersions

        Not implemented in the base class; always returns None.
        """
        ##Not Implemented
        return None

    def getProfile(self):
        """
        Get SLD profile

        :return: (z, beta) where z is a list of depth of the transition
            points and beta is a list of the corresponding SLD values.
            The base class returns (None, None).
        """
        #Not Implemented
        return None, None

    def _locate_param(self, name):
        """
        Find the dictionary and key holding parameter ``name``.

        A name with exactly one dot ("item.par") is looked up in
        ``self.dispersion``; any other name is looked up in ``self.params``.
        Both parts are matched case-insensitively.

        :return: (container, key) such that ``container[key]`` is the value
        :raises ValueError: if no matching parameter exists
        """
        toks = name.split('.')
        if len(toks) == 2:
            item_name, par_name = toks[0].lower(), toks[1].lower()
            for item in self.dispersion:
                if item.lower() == item_name:
                    for par in self.dispersion[item]:
                        if par.lower() == par_name:
                            return self.dispersion[item], par
        else:
            lowered = name.lower()
            for item in self.params:
                if item.lower() == lowered:
                    return self.params, item

        raise ValueError("Model does not contain parameter %s" % name)

    def setParam(self, name, value):
        """
        Set the value of a model parameter

        :param name: name of the parameter, or "item.par" for a dispersion
            parameter (case-insensitive)
        :param value: value of the parameter
        :raises ValueError: if the model does not contain the parameter
        """
        container, key = self._locate_param(name)
        container[key] = value

    def getParam(self, name):
        """
        Get the value of a model parameter

        :param name: name of the parameter, or "item.par" for a dispersion
            parameter (case-insensitive)
        :return: the stored value
        :raises ValueError: if the model does not contain the parameter
        """
        container, key = self._locate_param(name)
        return container[key]

    def getParamList(self):
        """
        Return a list of all available parameters for the model.

        Standard parameters come first (insertion order for an OrderedDict,
        sorted otherwise), followed by the dispersion parameters returned by
        getDispParamList().
        """
        param_names = _ordered_keys(self.params)
        # WARNING: Extending the list with the dispersion parameters
        param_names.extend(self.getDispParamList())
        return param_names

    def getDispParamList(self):
        """
        Return a list of all dispersion parameters for the model.

        Each entry is "item.par", lower-cased. The 'type' entry of each
        dispersion dictionary is skipped.
        """
        disp_names = []
        for item in _ordered_keys(self.dispersion):
            for par in _ordered_keys(self.dispersion[item]):
                if par != 'type':
                    disp_names.append('%s.%s' % (item.lower(), par.lower()))

        return disp_names

    # Old-style methods that are no longer used
    def setParamWithToken(self, name, value, token, member):
        """
        set Param With Token (obsolete; returns NotImplemented)
        """
        return NotImplemented

    def getParamWithToken(self, name, token, member):
        """
        get Param With Token (obsolete; returns NotImplemented)
        """
        return NotImplemented

    def getParamListWithToken(self, token, member):
        """
        get Param List With Token (obsolete; returns NotImplemented)
        """
        return NotImplemented

    def __add__(self, other):
        """
        add
        """
        raise ValueError("Model operation are no longer supported")

    def __sub__(self, other):
        """
        sub
        """
        raise ValueError("Model operation are no longer supported")

    def __mul__(self, other):
        """
        mul
        """
        raise ValueError("Model operation are no longer supported")

    def __div__(self, other):
        """
        div
        """
        raise ValueError("Model operation are no longer supported")

    # Python 3 dispatches "/" to __truediv__; alias it so division reports
    # the same "no longer supported" error instead of a generic TypeError.
    __truediv__ = __div__
312
313
def _ordered_keys(d):
    """
    Return the keys of mapping ``d`` as a new list.

    An OrderedDict keeps its insertion order; for any other mapping the
    keys are returned sorted.
    """
    if isinstance(d, OrderedDict):
        return list(d.keys())
    return sorted(d.keys())
Note: See TracBrowser for help on using the repository browser.