source: sasmodels/sasmodels/sasview_model.py @ 2622b3f

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 2622b3f was 2622b3f, checked in by gonzalezm, 8 years ago

Use ordered dict to preserve order of parameters as given in the model

  • Property mode set to 100644
File size: 12.2 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import make_class
8    from sasmodels.models import cylinder
9    CylinderModel = make_class(cylinder, dtype='single')
10
11The model parameters for sasmodels are different from those in sasview.
12When reloading previously saved models, the parameters should be converted
13using :func:`sasmodels.convert.convert`.
14"""
15
16import math
17from copy import deepcopy
18import warnings
19import collections
20
21import numpy as np
22
23from . import core
24
def make_class(model_info, dtype='single', namestyle='name'):
    """
    Build a sasview model class for *model_info*.

    *model_info* is compiled once with :func:`core.build_model` using
    precision *dtype* (default ``'single'``).  The result is a new
    subclass of :class:`SasviewModel` that can be used directly as a
    sasview model.  Every instance of the returned class wraps that same
    compiled model.

    The class name is ``model.info[namestyle]``.  It defaults to the new
    name for a model; setting *namestyle='oldname'* produces a class name
    compatible with SasView.
    """
    model = core.build_model(model_info, dtype=dtype)

    def __init__(self, multfactor=1):
        # *multfactor* is accepted but ignored.  It is presumably kept so
        # the constructor accepts the argument SasView may pass -- confirm.
        SasviewModel.__init__(self, model)

    attrs = dict(__init__=__init__)
    ConstructedModel = type(model.info[namestyle], (SasviewModel,), attrs)
    return ConstructedModel
41
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Presents a compiled sasmodels model through the attribute and method
    interface used by sasview:

    * attributes: ``params``, ``details``, ``dispersion``;
    * parameter access: ``setParam``/``getParam``;
    * evaluation: ``run``/``runXY``/``evalDistribution``;
    * derived quantities: ``calculate_ER``/``calculate_VR``.
    """
    def __init__(self, model):
        """
        Initialization.

        :param model: compiled sasmodels model.  Its ``info`` dict must
            provide ``name``, ``oldname``, ``description``, ``parameters``
            and ``partype``.  Calling the model with a list of q vectors
            must return a kernel (see :meth:`calculate_Iq`).
        """
        self._model = model

        self.name = model.info['name']
        self.oldname = model.info['oldname']
        self.description = model.info['description']
        self.category = None
        self.multiplicity_info = None
        self.is_multifunc = False

        ## interpret the parameters
        ## TODO: reorganize parameter handling
        # details: parameter name -> [units] + limits
        self.details = dict()
        # OrderedDict so parameters are reported in the order the model
        # declares them.
        self.params = collections.OrderedDict()
        # dispersion: parameter name -> {'width', 'npts', 'nsigmas', 'type'}
        self.dispersion = dict()
        partype = model.info['partype']

        for p in model.info['parameters']:
            self.params[p.name] = p.default
            self.details[p.name] = [p.units] + p.limits

        # Every 'pd-2d' parameter starts out as a zero-width gaussian.
        for name in partype['pd-2d']:
            self.dispersion[name] = {
                'width': 0,
                'npts': 35,
                'nsigmas': 3,
                'type': 'gaussian',
            }

        self.orientation_params = (
            partype['orientation']
            + [n + '.width' for n in partype['orientation']]
            + partype['magnetic'])
        self.magnetic_params = partype['magnetic']
        # is_fittable() tests membership in this list, so only the
        # '<name>.width' entries of the 'pd-2d' parameters report fittable.
        self.fixed = [n + '.width' for n in partype['pd-2d']]
        self.non_fittable = []

        ## independent parameter name and unit [string]
        self.input_name = model.info.get("input_name", "Q")
        self.input_unit = model.info.get("input_unit", "A^{-1}")
        self.output_name = model.info.get("output_name", "Intensity")
        self.output_unit = model.info.get("output_unit", "cm^{-1}")

        ## _persistency_dict is used by sas.perspectives.fitting.basepage
        ## to store dispersity reference.
        ## TODO: _persistency_dict to persistency_dict throughout sasview
        self._persistency_dict = {}

        ## New fields introduced for opencl rewrite
        # Passed to the kernel as its cutoff argument in calculate_Iq().
        self.cutoff = 1e-5

    def __str__(self):
        """
        :return: string representation (the model name)
        """
        return self.name

    def is_fittable(self, par_name):
        """
        Check if a given parameter is fittable or not.

        Returns True when the lower-cased *par_name* appears in
        ``self.fixed``, i.e. for the polydispersity widths
        (``'<name>.width'``) of the 'pd-2d' parameters.

        :param par_name: the parameter name to check
        """
        return par_name.lower() in self.fixed

    # pylint: disable=no-self-use
    def getProfile(self):
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values

        This wrapper has no profile and always returns (None, None).
        """
        return None, None

    def setParam(self, name, value):
        """
        Set the value of a model parameter.

        Names of the form ``'<param>.<field>'`` address a dispersion field
        (e.g. ``'radius.width'``); any other name addresses a standard
        parameter.  Matching is case-insensitive.

        :param name: name of the parameter
        :param value: value of the parameter
        :raises ValueError: if no parameter matches *name*
        """
        # Look for dispersion parameters
        toks = name.split('.')
        if len(toks) == 2:
            for item in self.dispersion.keys():
                if item.lower() == toks[0].lower():
                    for par in self.dispersion[item]:
                        if par.lower() == toks[1].lower():
                            self.dispersion[item][par] = value
                            return
        else:
            # Look for standard parameter
            for item in self.params.keys():
                if item.lower() == name.lower():
                    self.params[item] = value
                    return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        """
        Get the value of a model parameter.

        Accepts the same (case-insensitive) names as :meth:`setParam`.

        :param name: name of the parameter
        :return: the current value
        :raises ValueError: if no parameter matches *name*
        """
        # Look for dispersion parameters
        toks = name.split('.')
        if len(toks) == 2:
            for item in self.dispersion.keys():
                if item.lower() == toks[0].lower():
                    for par in self.dispersion[item]:
                        if par.lower() == toks[1].lower():
                            return self.dispersion[item][par]
        else:
            # Look for standard parameter
            for item in self.params.keys():
                if item.lower() == name.lower():
                    return self.params[item]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        """
        Return a list of all available parameters for the model.

        The list holds the standard parameters in model order, followed by
        the dispersion parameters from :meth:`getDispParamList`.
        """
        # list() is required: on Python 3, keys() is a view with no extend().
        param_list = list(self.params.keys())
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        """
        Return the dispersion parameter names for the model.

        For each 'pd-2d' parameter (lower-cased), yields
        ``'<name>.npts'``, ``'<name>.nsigmas'`` and ``'<name>.width'``.
        """
        # TODO: fix test so that parameter order doesn't matter
        return ['%s.%s' % (d.lower(), p)
                for d in self._model.info['partype']['pd-2d']
                for p in ('npts', 'nsigmas', 'width')]

    def clone(self):
        """ Return a identical copy of self """
        return deepcopy(self)

    def run(self, x=0.0):
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q * math.cos(phi)],
                                     [q * math.sin(phi)])[0]
        else:
            return self.calculate_Iq([float(x)])[0]

    def runXY(self, x=0.0):
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([float(x[0])], [float(x[1])])[0]
        else:
            return self.calculate_Iq([float(x)])[0]

    def evalDistribution(self, qdist):
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model has no orientation or magnetic parameters, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        :raises TypeError: for any other kind of input
        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            partype = self._model.info['partype']
            if not partype['orientation'] and not partype['magnetic']:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calculate_Iq(self, *args):
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        q_vectors = [np.asarray(q) for q in args]
        fn = self._model(q_vectors)
        try:
            pars = [self.params[v] for v in fn.fixed_pars]
            pd_pars = [self._get_weights(p) for p in fn.pd_pars]
            result = fn(pars, pd_pars, self.cutoff)
        finally:
            # Release the q buffer and the kernel even when parameter
            # lookup or evaluation raises, so failures do not leak them.
            fn.q_input.release()
            fn.release()
        return result

    def calculate_ER(self):
        """
        Calculate the effective radius for P(q)*S(q)

        Averages the model's ``ER`` function over the volume-parameter
        dispersion mesh; returns 1.0 when the model defines no ``ER``.

        :return: the value of the effective radius
        """
        ER = self._model.info.get('ER', None)
        if ER is None:
            return 1.0
        else:
            values, weights = self._dispersion_mesh()
            fv = ER(*values)
            return np.sum(weights * fv) / np.sum(weights)

    def calculate_VR(self):
        """
        Calculate the volf ratio for P(q)*S(q)

        Ratio of the weighted sums of the ``part`` and ``whole`` values
        returned by the model's ``VR`` function over the dispersion mesh.
        Returns 1.0 when the model defines no ``VR``.

        :return: the value of the volf ratio
        """
        VR = self._model.info.get('VR', None)
        if VR is None:
            return 1.0
        else:
            values, weights = self._dispersion_mesh()
            whole, part = VR(*values)
            return np.sum(weights * part) / np.sum(weights * whole)

    def set_dispersion(self, parameter, dispersion):
        """
        Set the dispersion object for a model parameter

        *parameter* is matched case-insensitively and the result is stored
        under the model's own spelling of the name.

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion
        :raises ValueError: if *parameter* is not a model parameter
        """
        canonical = None
        for name in self.params:
            if name.lower() == parameter.lower():
                canonical = name
                break
        if canonical is None:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

        # TODO: Store the disperser object directly in the model.
        # The current method of creating one on the fly whenever it is
        # needed is kind of funky.
        # Note: can't seem to get disperser parameters from sasview
        # (1) Could create a sasview model that has not yet # been
        # converted, assign the disperser to one of its polydisperse
        # parameters, then retrieve the disperser parameters from the
        # sasview model.  (2) Could write a disperser parameter retriever
        # in sasview.  (3) Could modify sasview to use sasmodels.weights
        # dispersers.
        # For now, rely on the fact that the sasview only ever uses
        # new dispersers in the set_dispersion call and create a new
        # one instead of trying to assign parameters.
        from . import weights
        disperser = weights.dispersers[dispersion.__class__.__name__]
        dispersion = weights.models[disperser]()
        self.dispersion[canonical] = dispersion.get_pars()

    def _dispersion_mesh(self):
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = self._model.info['partype']['volume']
        return core.dispersion_mesh([self._get_weights(p) for p in pars])

    def _get_weights(self, par):
        """
        Return dispersion weights.

        :param par: parameter name
        :return: (values, weights) from ``weights.get_weights``, with the
            weights normalized to sum to 1
        """
        from . import weights

        relative = self._model.info['partype']['pd-rel']
        limits = self._model.info['limits']
        dis = self.dispersion[par]
        value, weight = weights.get_weights(
            dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
            self.params[par], limits[par], par in relative)
        return value, weight / np.sum(weight)
Note: See TracBrowser for help on using the repository browser.