source: sasmodels/sasmodels/sasview_model.py @ 1780d59

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 1780d59 was 1780d59, checked in by Paul Kienzle <pkienzle@…>, 10 years ago

hack sasview model polydispersity

  • Property mode set to 100644
File size: 11.1 KB
Line 
1import math
2from copy import deepcopy
3
4import numpy as np
5
6def make_class(name, class_name=None, dtype='single'):
7    from .core import opencl_model
8    if class_name is None:
9        class_name = "".join(part.capitalize() for part in name.split('_')+['model'])
10    model =  opencl_model(name, dtype=dtype)
11    class ConstructedModel(SasviewModel):
12        kernel = model
13        def __init__(self, multfactor=1):
14            SasviewModel.__init__(self, self.kernel)
15    ConstructedModel.__name__ = class_name
16    return ConstructedModel
17
18class SasviewModel(object):
19    """
20    Sasview wrapper for opencl/ctypes model.
21    """
22    def __init__(self, model):
23        """Initialization"""
24        self._model = model
25
26        self.name = model.info['name']
27        self.description = model.info['description']
28        self.category = None
29        self.multiplicity_info = None
30        self.is_multifunc = False
31
32        ## interpret the parameters
33        ## TODO: reorganize parameter handling
34        self.details = dict()
35        self.params = dict()
36        self.dispersion = dict()
37        partype = model.info['partype']
38        for name,units,default,limits,ptype,description in model.info['parameters']:
39            self.params[name] = default
40            self.details[name] = [units]+limits
41
42        for name in partype['pd-2d']:
43            self.dispersion[name] = {
44                'width': 0,
45                'npts': 35,
46                'nsigmas': 3,
47                'type': 'gaussian',
48            }
49
50        self.orientation_params = (
51            partype['orientation']
52            + [n+'.width' for n in partype['orientation']]
53            + partype['magnetic'])
54        self.magnetic_params = partype['magnetic']
55        self.fixed = [n+'.width' for n in partype['pd-2d']]
56        self.non_fittable = []
57
58        ## independent parameter name and unit [string]
59        self.input_name = model.info.get("input_name","Q")
60        self.input_unit = model.info.get("input_unit","A^{-1}")
61        self.output_name = model.info.get("output_name","Intensity")
62        self.output_unit = model.info.get("output_unit","cm^{-1}")
63
64        ## _persistency_dict is used by sans.perspectives.fitting.basepage
65        ## to store dispersity reference.
66        ## TODO: _persistency_dict to persistency_dict throughout sasview
67        self._persistency_dict = {}
68
69        ## New fields introduced for opencl rewrite
70        self.cutoff = 1e-5
71
72    def __str__(self):
73        """
74        :return: string representation
75        """
76        return self.name
77
78    def is_fittable(self, par_name):
79        """
80        Check if a given parameter is fittable or not
81
82        :param par_name: the parameter name to check
83        """
84        return par_name.lower() in self.fixed
85        #For the future
86        #return self.params[str(par_name)].is_fittable()
87
88
89    def getProfile(self):
90        """
91        Get SLD profile
92
93        : return: (z, beta) where z is a list of depth of the transition points
94                beta is a list of the corresponding SLD values
95        """
96        return None, None
97
98    def setParam(self, name, value):
99        """
100        Set the value of a model parameter
101
102        :param name: name of the parameter
103        :param value: value of the parameter
104
105        """
106        # Look for dispersion parameters
107        toks = name.split('.')
108        if len(toks)==2:
109            for item in self.dispersion.keys():
110                if item.lower()==toks[0].lower():
111                    for par in self.dispersion[item]:
112                        if par.lower() == toks[1].lower():
113                            self.dispersion[item][par] = value
114                            return
115        else:
116            # Look for standard parameter
117            for item in self.params.keys():
118                if item.lower()==name.lower():
119                    self.params[item] = value
120                    return
121
122        raise ValueError, "Model does not contain parameter %s" % name
123
124    def getParam(self, name):
125        """
126        Set the value of a model parameter
127
128        :param name: name of the parameter
129
130        """
131        # Look for dispersion parameters
132        toks = name.split('.')
133        if len(toks)==2:
134            for item in self.dispersion.keys():
135                if item.lower()==toks[0].lower():
136                    for par in self.dispersion[item]:
137                        if par.lower() == toks[1].lower():
138                            return self.dispersion[item][par]
139        else:
140            # Look for standard parameter
141            for item in self.params.keys():
142                if item.lower()==name.lower():
143                    return self.params[item]
144
145        raise ValueError, "Model does not contain parameter %s" % name
146
147    def getParamList(self):
148        """
149        Return a list of all available parameters for the model
150        """
151        list = self.params.keys()
152        # WARNING: Extending the list with the dispersion parameters
153        list.extend(self.getDispParamList())
154        return list
155
156    def getDispParamList(self):
157        """
158        Return a list of all available parameters for the model
159        """
160        # TODO: fix test so that parameter order doesn't matter
161        ret = ['%s.%s'%(d.lower(), p)
162               for d in self._model.info['partype']['pd-2d']
163               for p in ('npts', 'nsigmas', 'width')]
164        #print ret
165        return ret
166
167    def clone(self):
168        """ Return a identical copy of self """
169        return deepcopy(self)
170
171    def run(self, x=0.0):
172        """
173        Evaluate the model
174
175        :param x: input q, or [q,phi]
176
177        :return: scattering function P(q)
178
179        **DEPRECATED**: use calculate_Iq instead
180        """
181        if isinstance(x, (list,tuple)):
182            q, phi = x
183            return self.calculate_Iq([q * math.cos(phi)],
184                                     [q * math.sin(phi)])[0]
185        else:
186            return self.calculate_Iq([float(x)])[0]
187
188
189    def runXY(self, x=0.0):
190        """
191        Evaluate the model in cartesian coordinates
192
193        :param x: input q, or [qx, qy]
194
195        :return: scattering function P(q)
196
197        **DEPRECATED**: use calculate_Iq instead
198        """
199        if isinstance(x, (list,tuple)):
200            return self.calculate_Iq([float(x[0])],[float(x[1])])[0]
201        else:
202            return self.calculate_Iq([float(x)])[0]
203
204    def evalDistribution(self, qdist):
205        """
206        Evaluate a distribution of q-values.
207
208        * For 1D, a numpy array is expected as input: ::
209
210            evalDistribution(q)
211
212          where q is a numpy array.
213
214        * For 2D, a list of numpy arrays are expected: [qx,qy],
215          with 1D arrays::
216
217              qx = [ qx[0], qx[1], qx[2], ....]
218
219          and::
220
221              qy = [ qy[0], qy[1], qy[2], ....]
222
223        Then get ::
224
225            q = numpy.sqrt(qx^2+qy^2)
226
227        that is a qr in 1D array::
228
229            q = [q[0], q[1], q[2], ....]
230
231
232        :param qdist: ndarray of scalar q-values or list [qx,qy] where qx,qy are 1D ndarrays
233        """
234        if isinstance(qdist, (list,tuple)):
235            # Check whether we have a list of ndarrays [qx,qy]
236            qx, qy = qdist
237            return self.calculate_Iq(qx, qy)
238
239        elif isinstance(qdist, np.ndarray):
240            # We have a simple 1D distribution of q-values
241            return self.calculate_Iq(qdist)
242
243        else:
244            raise TypeError("evalDistribution expects q or [qx, qy], not %r"%type(qdist))
245
246    def calculate_Iq(self, *args):
247        q_vectors = [np.asarray(q) for q in args]
248        fn = self._model(self._model.make_input(q_vectors))
249        pars = [self.params[v] for v in fn.fixed_pars]
250        pd_pars = [self._get_weights(p) for p in fn.pd_pars]
251        result = fn(pars, pd_pars, self.cutoff)
252        fn.input.release()
253        fn.release()
254        return result
255
256    def calculate_ER(self):
257        """
258        Calculate the effective radius for P(q)*S(q)
259
260        :return: the value of the effective radius
261        """
262        ER = self._model.info.get('ER', None)
263        if ER is None:
264            return 1.0
265        else:
266            vol_pars = self._model.info['partype']['volume']
267            values, weights = self._dispersion_mesh(vol_pars)
268            fv = ER(*values)
269            return np.sum(weights*fv) / np.sum(weights)
270
271    def calculate_VR(self):
272        """
273        Calculate the volf ratio for P(q)*S(q)
274
275        :return: the value of the volf ratio
276        """
277        VR = self._model.info.get('ER', None)
278        if VR is None:
279            return 1.0
280        else:
281            vol_pars = self._model.info['partype']['volume']
282            values, weights = self._dispersion_mesh(vol_pars)
283            whole,part = VR(*values)
284            return np.sum(weights*part)/np.sum(weights*whole)
285
286    def set_dispersion(self, parameter, dispersion):
287        """
288        Set the dispersion object for a model parameter
289
290        :param parameter: name of the parameter [string]
291        :param dispersion: dispersion object of type Dispersion
292        """
293        if parameter.lower() in (s.lower() for s in self.params.keys()):
294            # TODO: Store the disperser object directly in the model.
295            # The current method of creating one on the fly whenever it is
296            # needed is kind of funky.
297            # Note: can't seem to get disperser parameters from sasview
298            # (1) Could create a sasview model that has not yet # been
299            # converted, assign the disperser to one of its polydisperse
300            # parameters, then retrieve the disperser parameters from the
301            # sasview model.  (2) Could write a disperser parameter retriever
302            # in sasview.  (3) Could modify sasview to use sasmodels.weights
303            # dispersers.
304            # For now, rely on the fact that the sasview only ever uses
305            # new dispersers in the set_dispersion call and create a new
306            # one instead of trying to assign parameters.
307            from . import weights
308            disperser = weights.dispersers[dispersion.__class__.__name__]
309            dispersion = weights.models[disperser]()
310            self.dispersion[parameter] = dispersion.get_pars()
311        else:
312            raise ValueError("%r is not a dispersity or orientation parameter")
313
314    def _dispersion_mesh(self, pars):
315        """
316        Create a mesh grid of dispersion parameters and weights.
317
318        Returns [p1,p2,...],w where pj is a vector of values for parameter j
319        and w is a vector containing the products for weights for each
320        parameter set in the vector.
321        """
322        values, weights = zip(*[self._get_weights(p) for p in pars])
323        values = [v.flatten() for v in np.meshgrid(*values)]
324        weights = np.vstack([v.flatten() for v in np.meshgrid(*weights)])
325        weights = np.prod(weights, axis=1)
326        return values, weights
327
328    def _get_weights(self, par):
329        from . import weights
330
331        relative = self._model.info['partype']['pd-rel']
332        limits = self._model.info['limits']
333        dis = self.dispersion[par]
334        v,w = weights.get_weights(
335            dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
336            self.params[par], limits[par], par in relative)
337        return v,w/w.max()
338
Note: See TracBrowser for help on using the repository browser.