source: sasmodels/sasmodels/sasview_model.py @ 32c160a

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 32c160a was 32c160a, checked in by Paul Kienzle <pkienzle@…>, 10 years ago

support ER/VR python kernels; move metadata to python

  • Property mode set to 100644
File size: 11.2 KB
Line 
1import math
2from copy import deepcopy
3
4import numpy as np
5
6def make_class(kernel_module, dtype='single'):
7    from .core import opencl_model
8    model =  opencl_model(kernel_module, dtype=dtype)
9    def __init__(self, multfactor=1):
10        SasviewModel.__init__(self, self.kernel)
11    attrs = dict(__init__=__init__, kernel=model)
12    ConstructedModel = type(model.info['name'],  (SasviewModel,), attrs)
13    #class ConstructedModel(SasviewModel):
14    #    kernel = model
15    #    def __init__(self, multfactor=1):
16    #        SasviewModel.__init__(self, self.kernel)
17    #ConstructedModel.__name__ = model.info['name']
18    return ConstructedModel
19
20class SasviewModel(object):
21    """
22    Sasview wrapper for opencl/ctypes model.
23    """
24    def __init__(self, model):
25        """Initialization"""
26        self._model = model
27
28        self.name = model.info['name']
29        self.description = model.info['description']
30        self.category = None
31        self.multiplicity_info = None
32        self.is_multifunc = False
33
34        ## interpret the parameters
35        ## TODO: reorganize parameter handling
36        self.details = dict()
37        self.params = dict()
38        self.dispersion = dict()
39        partype = model.info['partype']
40        for name,units,default,limits,ptype,description in model.info['parameters']:
41            self.params[name] = default
42            self.details[name] = [units]+limits
43
44        for name in partype['pd-2d']:
45            self.dispersion[name] = {
46                'width': 0,
47                'npts': 35,
48                'nsigmas': 3,
49                'type': 'gaussian',
50            }
51
52        self.orientation_params = (
53            partype['orientation']
54            + [n+'.width' for n in partype['orientation']]
55            + partype['magnetic'])
56        self.magnetic_params = partype['magnetic']
57        self.fixed = [n+'.width' for n in partype['pd-2d']]
58        self.non_fittable = []
59
60        ## independent parameter name and unit [string]
61        self.input_name = model.info.get("input_name","Q")
62        self.input_unit = model.info.get("input_unit","A^{-1}")
63        self.output_name = model.info.get("output_name","Intensity")
64        self.output_unit = model.info.get("output_unit","cm^{-1}")
65
66        ## _persistency_dict is used by sans.perspectives.fitting.basepage
67        ## to store dispersity reference.
68        ## TODO: _persistency_dict to persistency_dict throughout sasview
69        self._persistency_dict = {}
70
71        ## New fields introduced for opencl rewrite
72        self.cutoff = 1e-5
73
74    def __str__(self):
75        """
76        :return: string representation
77        """
78        return self.name
79
80    def is_fittable(self, par_name):
81        """
82        Check if a given parameter is fittable or not
83
84        :param par_name: the parameter name to check
85        """
86        return par_name.lower() in self.fixed
87        #For the future
88        #return self.params[str(par_name)].is_fittable()
89
90
91    def getProfile(self):
92        """
93        Get SLD profile
94
95        : return: (z, beta) where z is a list of depth of the transition points
96                beta is a list of the corresponding SLD values
97        """
98        return None, None
99
100    def setParam(self, name, value):
101        """
102        Set the value of a model parameter
103
104        :param name: name of the parameter
105        :param value: value of the parameter
106
107        """
108        # Look for dispersion parameters
109        toks = name.split('.')
110        if len(toks)==2:
111            for item in self.dispersion.keys():
112                if item.lower()==toks[0].lower():
113                    for par in self.dispersion[item]:
114                        if par.lower() == toks[1].lower():
115                            self.dispersion[item][par] = value
116                            return
117        else:
118            # Look for standard parameter
119            for item in self.params.keys():
120                if item.lower()==name.lower():
121                    self.params[item] = value
122                    return
123
124        raise ValueError, "Model does not contain parameter %s" % name
125
126    def getParam(self, name):
127        """
128        Set the value of a model parameter
129
130        :param name: name of the parameter
131
132        """
133        # Look for dispersion parameters
134        toks = name.split('.')
135        if len(toks)==2:
136            for item in self.dispersion.keys():
137                if item.lower()==toks[0].lower():
138                    for par in self.dispersion[item]:
139                        if par.lower() == toks[1].lower():
140                            return self.dispersion[item][par]
141        else:
142            # Look for standard parameter
143            for item in self.params.keys():
144                if item.lower()==name.lower():
145                    return self.params[item]
146
147        raise ValueError, "Model does not contain parameter %s" % name
148
149    def getParamList(self):
150        """
151        Return a list of all available parameters for the model
152        """
153        list = self.params.keys()
154        # WARNING: Extending the list with the dispersion parameters
155        list.extend(self.getDispParamList())
156        return list
157
158    def getDispParamList(self):
159        """
160        Return a list of all available parameters for the model
161        """
162        # TODO: fix test so that parameter order doesn't matter
163        ret = ['%s.%s'%(d.lower(), p)
164               for d in self._model.info['partype']['pd-2d']
165               for p in ('npts', 'nsigmas', 'width')]
166        #print ret
167        return ret
168
169    def clone(self):
170        """ Return a identical copy of self """
171        return deepcopy(self)
172
173    def run(self, x=0.0):
174        """
175        Evaluate the model
176
177        :param x: input q, or [q,phi]
178
179        :return: scattering function P(q)
180
181        **DEPRECATED**: use calculate_Iq instead
182        """
183        if isinstance(x, (list,tuple)):
184            q, phi = x
185            return self.calculate_Iq([q * math.cos(phi)],
186                                     [q * math.sin(phi)])[0]
187        else:
188            return self.calculate_Iq([float(x)])[0]
189
190
191    def runXY(self, x=0.0):
192        """
193        Evaluate the model in cartesian coordinates
194
195        :param x: input q, or [qx, qy]
196
197        :return: scattering function P(q)
198
199        **DEPRECATED**: use calculate_Iq instead
200        """
201        if isinstance(x, (list,tuple)):
202            return self.calculate_Iq([float(x[0])],[float(x[1])])[0]
203        else:
204            return self.calculate_Iq([float(x)])[0]
205
206    def evalDistribution(self, qdist):
207        """
208        Evaluate a distribution of q-values.
209
210        * For 1D, a numpy array is expected as input: ::
211
212            evalDistribution(q)
213
214          where q is a numpy array.
215
216        * For 2D, a list of numpy arrays are expected: [qx,qy],
217          with 1D arrays::
218
219              qx = [ qx[0], qx[1], qx[2], ....]
220
221          and::
222
223              qy = [ qy[0], qy[1], qy[2], ....]
224
225        Then get ::
226
227            q = numpy.sqrt(qx^2+qy^2)
228
229        that is a qr in 1D array::
230
231            q = [q[0], q[1], q[2], ....]
232
233
234        :param qdist: ndarray of scalar q-values or list [qx,qy] where qx,qy are 1D ndarrays
235        """
236        if isinstance(qdist, (list,tuple)):
237            # Check whether we have a list of ndarrays [qx,qy]
238            qx, qy = qdist
239            return self.calculate_Iq(qx, qy)
240
241        elif isinstance(qdist, np.ndarray):
242            # We have a simple 1D distribution of q-values
243            return self.calculate_Iq(qdist)
244
245        else:
246            raise TypeError("evalDistribution expects q or [qx, qy], not %r"%type(qdist))
247
248    def calculate_Iq(self, *args):
249        q_vectors = [np.asarray(q) for q in args]
250        fn = self._model(self._model.make_input(q_vectors))
251        pars = [self.params[v] for v in fn.fixed_pars]
252        pd_pars = [self._get_weights(p) for p in fn.pd_pars]
253        result = fn(pars, pd_pars, self.cutoff)
254        fn.input.release()
255        fn.release()
256        return result
257
258    def calculate_ER(self):
259        """
260        Calculate the effective radius for P(q)*S(q)
261
262        :return: the value of the effective radius
263        """
264        ER = self._model.info.get('ER', None)
265        if ER is None:
266            return 1.0
267        else:
268            vol_pars = self._model.info['partype']['volume']
269            values, weights = self._dispersion_mesh(vol_pars)
270            fv = ER(*values)
271            #print values[0].shape, weights.shape, fv.shape
272            return np.sum(weights*fv) / np.sum(weights)
273
274    def calculate_VR(self):
275        """
276        Calculate the volf ratio for P(q)*S(q)
277
278        :return: the value of the volf ratio
279        """
280        VR = self._model.info.get('VR', None)
281        if VR is None:
282            return 1.0
283        else:
284            vol_pars = self._model.info['partype']['volume']
285            values, weights = self._dispersion_mesh(vol_pars)
286            whole,part = VR(*values)
287            return np.sum(weights*part)/np.sum(weights*whole)
288
289    def set_dispersion(self, parameter, dispersion):
290        """
291        Set the dispersion object for a model parameter
292
293        :param parameter: name of the parameter [string]
294        :param dispersion: dispersion object of type Dispersion
295        """
296        if parameter.lower() in (s.lower() for s in self.params.keys()):
297            # TODO: Store the disperser object directly in the model.
298            # The current method of creating one on the fly whenever it is
299            # needed is kind of funky.
300            # Note: can't seem to get disperser parameters from sasview
301            # (1) Could create a sasview model that has not yet # been
302            # converted, assign the disperser to one of its polydisperse
303            # parameters, then retrieve the disperser parameters from the
304            # sasview model.  (2) Could write a disperser parameter retriever
305            # in sasview.  (3) Could modify sasview to use sasmodels.weights
306            # dispersers.
307            # For now, rely on the fact that the sasview only ever uses
308            # new dispersers in the set_dispersion call and create a new
309            # one instead of trying to assign parameters.
310            from . import weights
311            disperser = weights.dispersers[dispersion.__class__.__name__]
312            dispersion = weights.models[disperser]()
313            self.dispersion[parameter] = dispersion.get_pars()
314        else:
315            raise ValueError("%r is not a dispersity or orientation parameter")
316
317    def _dispersion_mesh(self, pars):
318        """
319        Create a mesh grid of dispersion parameters and weights.
320
321        Returns [p1,p2,...],w where pj is a vector of values for parameter j
322        and w is a vector containing the products for weights for each
323        parameter set in the vector.
324        """
325        values, weights = zip(*[self._get_weights(p) for p in pars])
326        values = [v.flatten() for v in np.meshgrid(*values)]
327        weights = np.vstack([v.flatten() for v in np.meshgrid(*weights)])
328        weights = np.prod(weights, axis=0)
329        return values, weights
330
331    def _get_weights(self, par):
332        from . import weights
333
334        relative = self._model.info['partype']['pd-rel']
335        limits = self._model.info['limits']
336        dis = self.dispersion[par]
337        v,w = weights.get_weights(
338            dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
339            self.params[par], limits[par], par in relative)
340        return v,w/w.max()
341
Note: See TracBrowser for help on using the repository browser.