source: sasmodels/sasmodels/sasview_model.py @ 08376e7

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 08376e7 was 08376e7, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

sasview wrapper hackery

  • Property mode set to 100644
File size: 12.8 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import make_class
8    from sasmodels.models import cylinder
9    CylinderModel = make_class(cylinder, dtype='single')
10
11The model parameters for sasmodels are different from those in sasview.
12When reloading previously saved models, the parameters should be converted
13using :func:`sasmodels.convert.convert`.
14"""
15
16import math
17from copy import deepcopy
18import warnings
19import collections
20
21import numpy as np
22
23from . import core
24
def standard_models():
    """
    Build one sasview model class per model listed by :func:`core.list_models`.

    :return: list of classes created with :func:`make_class`, in the order
             the names are returned by ``core.list_models()``.
    """
    classes = []
    for name in core.list_models():
        classes.append(make_class(name))
    return classes
27
def make_class(model_name, namestyle='name'):
    """
    Load the sasmodels model *model_name* and wrap it in a sasview model class.

    Returns a class (a subclass of :class:`SasviewModel`) that can be used
    directly as a sasview model.

    :param model_name: model name passed to :func:`core.load_model_info`.
    :param namestyle: key of the model info used as the class name.  The
        default *namestyle='name'* uses the new sasmodels name; setting
        *namestyle='oldname'* produces a class whose name is compatible
        with SasView.
    :return: the constructed model class.
    """
    model_info = core.load_model_info(model_name)

    def __init__(self, multfactor=1):
        # *multfactor* is accepted but ignored; presumably kept so the
        # constructor matches the signature SasView uses -- TODO confirm.
        SasviewModel.__init__(self)

    # The model info is attached as a class attribute, which is what
    # SasviewModel.__init__ reads via self._model_info.
    attrs = dict(__init__=__init__, _model_info=model_info)
    ConstructedModel = type(model_info[namestyle], (SasviewModel,), attrs)
    return ConstructedModel
44
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete subclasses are built by :func:`make_class`, which supplies the
    class attribute *_model_info* that :meth:`__init__` reads.
    """
    def __init__(self):
        # The compiled kernel is built lazily by calculate_Iq.
        self._kernel = None
        model_info = self._model_info

        self.name = model_info['name']
        self.oldname = model_info['oldname']
        self.description = model_info['description']
        self.category = None
        self.multiplicity_info = None
        self.is_multifunc = False

        ## interpret the parameters
        ## TODO: reorganize parameter handling
        self.details = dict()
        self.params = collections.OrderedDict()
        self.dispersion = dict()
        partype = model_info['partype']

        for p in model_info['parameters']:
            self.params[p.name] = p.default
            # [units] followed by the parameter limits; list() lets the
            # limits be given as either a list or a tuple.
            self.details[p.name] = [p.units] + list(p.limits)

        # Every 2D-polydisperse parameter starts with a zero-width gaussian.
        for name in partype['pd-2d']:
            self.dispersion[name] = {
                'width': 0,
                'npts': 35,
                'nsigmas': 3,
                'type': 'gaussian',
            }

        self.orientation_params = (
            partype['orientation']
            + [n + '.width' for n in partype['orientation']]
            + partype['magnetic'])
        self.magnetic_params = partype['magnetic']
        self.fixed = [n + '.width' for n in partype['pd-2d']]
        self.non_fittable = []

        ## independent parameter name and unit [string]
        self.input_name = model_info.get("input_name", "Q")
        self.input_unit = model_info.get("input_unit", "A^{-1}")
        self.output_name = model_info.get("output_name", "Intensity")
        self.output_unit = model_info.get("output_unit", "cm^{-1}")

        ## _persistency_dict is used by sas.perspectives.fitting.basepage
        ## to store dispersity reference.
        ## TODO: _persistency_dict to persistency_dict throughout sasview
        self._persistency_dict = {}

        ## New fields introduced for opencl rewrite
        self.cutoff = 1e-5

    def __getstate__(self):
        """
        Return the state used by pickle and copy.

        The compiled kernel is dropped; it is rebuilt on demand by
        :meth:`calculate_Iq` after the object is restored.
        """
        state = self.__dict__.copy()
        state.pop('_kernel', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        return state

    def __setstate__(self, state):
        """Restore state produced by :meth:`__getstate__`; the kernel is reset."""
        self.__dict__ = state
        self._kernel = None

    # Backward-compatible aliases for the previous misspelled names, which
    # the pickle/copy protocols never invoked.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check
        :return: True if the lower-cased *par_name* is listed in
                 ``self.fixed`` (the ``<par>.width`` names of the 2D
                 polydisperse parameters).
        """
        # NOTE(review): this tests membership in self.fixed, which reads
        # inverted relative to the method name; kept as-is because SasView
        # callers may depend on it -- confirm the intended semantics.
        # TODO: return self.params[str(par_name)].is_fittable()
        return par_name.lower() in self.fixed

    # pylint: disable=no-self-use
    def getProfile(self):
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        return None, None

    def setParam(self, name, value):
        """
        Set the value of a model parameter

        Names are matched case-insensitively.  A name of the form
        ``<par>.<field>`` sets a dispersion field (e.g. ``radius.width``).

        :param name: name of the parameter
        :param value: value of the parameter
        :raises ValueError: if no matching parameter exists
        """
        # Look for dispersion parameters
        toks = name.split('.')
        if len(toks) == 2:
            for item in self.dispersion.keys():
                if item.lower() == toks[0].lower():
                    for par in self.dispersion[item]:
                        if par.lower() == toks[1].lower():
                            self.dispersion[item][par] = value
                            return
        else:
            # Look for standard parameter
            for item in self.params.keys():
                if item.lower() == name.lower():
                    self.params[item] = value
                    return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        """
        Get the value of a model parameter

        Names are matched case-insensitively.  A name of the form
        ``<par>.<field>`` reads a dispersion field (e.g. ``radius.width``).

        :param name: name of the parameter
        :return: the current value
        :raises ValueError: if no matching parameter exists
        """
        # Look for dispersion parameters
        toks = name.split('.')
        if len(toks) == 2:
            for item in self.dispersion.keys():
                if item.lower() == toks[0].lower():
                    for par in self.dispersion[item]:
                        if par.lower() == toks[1].lower():
                            return self.dispersion[item][par]
        else:
            # Look for standard parameter
            for item in self.params.keys():
                if item.lower() == name.lower():
                    return self.params[item]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        """
        Return a list of all available parameters for the model

        The plain parameter names come first, followed by the dispersion
        parameter names from :meth:`getDispParamList`.
        """
        # list() is required: on Python 3 dict.keys() is a view without extend.
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        """
        Return the dispersion parameter names for the model

        Each 2D polydisperse parameter contributes ``<par>.npts``,
        ``<par>.nsigmas`` and ``<par>.width`` (parameter name lower-cased).
        """
        # TODO: fix test so that parameter order doesn't matter
        return ['%s.%s' % (d.lower(), p)
                for d in self._model_info['partype']['pd-2d']
                for p in ('npts', 'nsigmas', 'width')]

    def clone(self):
        """ Return a identical copy of self """
        return deepcopy(self)

    def run(self, x=0.0):
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q * math.cos(phi)],
                                     [q * math.sin(phi)])[0]
        else:
            return self.calculate_Iq([float(x)])[0]

    def runXY(self, x=0.0):
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([float(x[0])], [float(x[1])])[0]
        else:
            return self.calculate_Iq([float(x)])[0]

    def evalDistribution(self, qdist):
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        :raises TypeError: if *qdist* is neither a list/tuple nor an ndarray
        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]; asarray also
            # accepts plain sequences, which do not support ** elementwise.
            qx, qy = qdist
            qx, qy = np.asarray(qx), np.asarray(qy)
            partype = self._model_info['partype']
            if not partype['orientation'] and not partype['magnetic']:
                # No orientation dependence: collapse to |q|.
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calculate_Iq(self, *args):
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        if self._kernel is None:
            self._kernel = core.build_model(self._model_info)
        q_vectors = [np.asarray(q) for q in args]
        fn = self._kernel(q_vectors)
        try:
            pars = [self.params[v] for v in fn.fixed_pars]
            pd_pars = [self._get_weights(p) for p in fn.pd_pars]
            result = fn(pars, pd_pars, self.cutoff)
        finally:
            # Release kernel resources even if evaluation fails.
            try:
                fn.q_input.release()
            finally:
                fn.release()
        return result

    def calculate_ER(self):
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius, or 1.0 if the model
                 defines no ER function
        """
        ER = self._model_info.get('ER', None)
        if ER is None:
            return 1.0
        else:
            values, weights = self._dispersion_mesh()
            fv = ER(*values)
            # Weighted average over the volume-parameter dispersion mesh.
            return np.sum(weights * fv) / np.sum(weights)

    def calculate_VR(self):
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio, or 1.0 if the model defines
                 no VR function
        """
        VR = self._model_info.get('VR', None)
        if VR is None:
            return 1.0
        else:
            values, weights = self._dispersion_mesh()
            whole, part = VR(*values)
            return np.sum(weights * part) / np.sum(weights * whole)

    def set_dispersion(self, parameter, dispersion):
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string], matched
                          case-insensitively
        :param dispersion: dispersion object of type Dispersion
        :raises ValueError: if *parameter* is not a model parameter
        """
        matches = [name for name in self.params.keys()
                   if name.lower() == parameter.lower()]
        if not matches:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % (parameter,))
        # Store under the canonical parameter name so that _get_weights and
        # getParam/setParam see the new settings.
        canonical = matches[0]
        # TODO: Store the disperser object directly in the model.
        # The current method of creating one on the fly whenever it is
        # needed is kind of funky.
        # Note: can't seem to get disperser parameters from sasview
        # (1) Could create a sasview model that has not yet # been
        # converted, assign the disperser to one of its polydisperse
        # parameters, then retrieve the disperser parameters from the
        # sasview model.  (2) Could write a disperser parameter retriever
        # in sasview.  (3) Could modify sasview to use sasmodels.weights
        # dispersers.
        # For now, rely on the fact that the sasview only ever uses
        # new dispersers in the set_dispersion call and create a new
        # one instead of trying to assign parameters.
        from . import weights
        disperser = weights.dispersers[dispersion.__class__.__name__]
        dispersion = weights.models[disperser]()
        self.dispersion[canonical] = dispersion.get_pars()

    def _dispersion_mesh(self):
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = self._model_info['partype']['volume']
        return core.dispersion_mesh([self._get_weights(p) for p in pars])

    def _get_weights(self, par):
        """
        Return dispersion points and normalized weights for parameter *par*.

        :param par: parameter name
        :return: (value, weight) with weight scaled to sum to 1
        """
        from . import weights

        relative = self._model_info['partype']['pd-rel']
        limits = self._model_info['limits']
        dis = self.dispersion[par]
        value, weight = weights.get_weights(
            dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
            self.params[par], limits[par], par in relative)
        return value, weight / np.sum(weight)
395
Note: See TracBrowser for help on using the repository browser.