source: sasmodels/sasmodels/sasview_model.py @ 352964b

core_shell_microgels costrafo411 magnetic_model release_v0.94 release_v0.95 ticket-1257-vesicle-product ticket_1156 ticket_1265_superball ticket_822_more_unit_tests
Last change on this file since 352964b was 787be86, checked in by ajj, 8 years ago

remove import custom that is causing breakage

  • Property mode set to 100644
File size: 13.1 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import make_class
8    from sasmodels.models import cylinder
9    CylinderModel = make_class(cylinder, dtype='single')
10
11The model parameters for sasmodels are different from those in sasview.
12When reloading previously saved models, the parameters should be converted
13using :func:`sasmodels.convert.convert`.
14"""
15
16import math
17from copy import deepcopy
18import collections
19
20import numpy as np
21
22from . import core
23from . import generate
24
def standard_models():
    """
    Build a sasview model class for every model listed by sasmodels.core.

    :return: list of classes, one per name from :func:`core.list_models`,
        in the same order.
    """
    model_classes = []
    for model_name in core.list_models():
        model_classes.append(make_class(model_name))
    return model_classes
27
# TODO: rename to make_class_from_name and update sasview
def make_class(model_name):
    """
    Load the sasview model named *model_name*.

    The model definition is located with :func:`core.load_model_info` and
    the class itself is built by :func:`make_class_from_info`.

    :param model_name: name of the sasmodels model to load
    :return: a class that can be used directly as a sasview model
    """
    model_info = core.load_model_info(model_name)
    return make_class_from_info(model_info)
41
def make_class_from_file(path):
    """
    Build a sasview model class from the model definition stored at *path*.

    :param path: location of the model definition file
    :return: class produced by :func:`make_class_from_info`
    """
    return make_class_from_info(core.load_model_info_from_path(path))
45
def make_class_from_info(model_info):
    """
    Create a :class:`SasviewModel` subclass bound to *model_info*.

    The new class is named ``model_info['name']`` and stores the model
    definition in the class attribute ``_model_info``, which
    ``SasviewModel.__init__`` reads.
    """
    def __init__(self, multfactor=1):
        # multfactor is accepted but ignored; presumably kept so the
        # constructor signature matches what SasView passes -- confirm.
        SasviewModel.__init__(self)

    class_name = model_info['name']
    class_attrs = {'__init__': __init__, '_model_info': model_info}
    return type(class_name, (SasviewModel,), class_attrs)
52
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete subclasses are expected to provide the model definition as the
    class attribute ``_model_info`` (see :func:`make_class_from_info`).
    """
    def __init__(self):
        """
        Initialise parameter values, details and dispersion settings from
        ``self._model_info``.
        """
        # Kernel is built lazily on first use in calculate_Iq.
        self._kernel = None
        model_info = self._model_info

        self.name = model_info['name']
        self.oldname = model_info['oldname']
        self.description = model_info['description']
        self.category = None
        self.multiplicity_info = None
        self.is_multifunc = False

        ## interpret the parameters
        ## TODO: reorganize parameter handling
        self.details = dict()
        self.params = collections.OrderedDict()
        self.dispersion = dict()
        partype = model_info['partype']

        for p in model_info['parameters']:
            self.params[p.name] = p.default
            # details entry is [units, limit, limit, ...]; list() lets the
            # limits be given as either a list or a tuple.
            self.details[p.name] = [p.units] + list(p.limits)

        for name in partype['pd-2d']:
            self.dispersion[name] = {
                'width': 0,
                'npts': 35,
                'nsigmas': 3,
                'type': 'gaussian',
            }

        self.orientation_params = (
            partype['orientation']
            + [n + '.width' for n in partype['orientation']]
            + partype['magnetic'])
        self.magnetic_params = partype['magnetic']
        self.fixed = [n + '.width' for n in partype['pd-2d']]
        self.non_fittable = []

        ## independent parameter name and unit [string]
        self.input_name = model_info.get("input_name", "Q")
        self.input_unit = model_info.get("input_unit", "A^{-1}")
        self.output_name = model_info.get("output_name", "Intensity")
        self.output_unit = model_info.get("output_unit", "cm^{-1}")

        ## _persistency_dict is used by sas.perspectives.fitting.basepage
        ## to store dispersity reference.
        ## TODO: _persistency_dict to persistency_dict throughout sasview
        self._persistency_dict = {}

        ## New fields introduced for opencl rewrite
        self.cutoff = 1e-5

    def __getstate__(self):
        """
        Return the instance state for pickle/copy, without the compiled kernel.

        The kernel is rebuilt on demand by :meth:`calculate_Iq`.
        """
        state = self.__dict__.copy()
        state.pop('_kernel', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.  Note that _model_info is a
        # class attribute, so it is not part of the instance state here.
        return state

    def __setstate__(self, state):
        """
        Restore state produced by :meth:`__getstate__`; the kernel is reset
        so that it is rebuilt lazily.
        """
        self.__dict__ = state
        self._kernel = None

    # Backward-compatible aliases for the former (non-protocol) method names.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check
        :return: True if the lower-cased *par_name* is listed in self.fixed
        """
        # NOTE(review): this tests membership in self.fixed (the "<par>.width"
        # names), which reads as the opposite of "fittable"; the semantics
        # are kept unchanged because callers may depend on them -- confirm.
        # TODO: eventually use self.params[str(par_name)].is_fittable()
        return par_name.lower() in self.fixed

    # pylint: disable=no-self-use
    def getProfile(self):
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        return None, None

    def _find_param(self, name):
        """
        Locate the storage slot for parameter *name*, ignoring case.

        A two-part name "par.field" resolves to ``self.dispersion[par][field]``;
        any other name resolves to ``self.params[name]``.

        :return: (container, key) such that ``container[key]`` is the value
        :raises ValueError: if no matching parameter exists
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            base, field = toks[0].lower(), toks[1].lower()
            for item, settings in self.dispersion.items():
                if item.lower() == base:
                    for par in settings:
                        if par.lower() == field:
                            return settings, par
        else:
            # Look for standard parameter
            target = name.lower()
            for item in self.params:
                if item.lower() == target:
                    return self.params, item

        raise ValueError("Model does not contain parameter %s" % name)

    def setParam(self, name, value):
        """
        Set the value of a model parameter

        :param name: name of the parameter, or "par.field" for a dispersion
            setting (case-insensitive)
        :param value: value of the parameter
        :raises ValueError: if the model has no such parameter
        """
        container, key = self._find_param(name)
        container[key] = value

    def getParam(self, name):
        """
        Get the value of a model parameter

        :param name: name of the parameter, or "par.field" for a dispersion
            setting (case-insensitive)
        :return: the stored value
        :raises ValueError: if the model has no such parameter
        """
        container, key = self._find_param(name)
        return container[key]

    def getParamList(self):
        """
        Return a list of all available parameters for the model

        Standard parameters come first (in definition order), followed by
        the dispersion parameters from :meth:`getDispParamList`.
        """
        # list() is required: on Python 3 keys() is a view with no extend().
        param_list = list(self.params.keys())
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        """
        Return a list of all available dispersion parameters for the model

        Names have the form "<par>.npts", "<par>.nsigmas", "<par>.width" for
        each parameter in the model's 'pd-2d' list, with <par> lower-cased.
        """
        # TODO: fix test so that parameter order doesn't matter
        return ['%s.%s' % (d.lower(), p)
                for d in self._model_info['partype']['pd-2d']
                for p in ('npts', 'nsigmas', 'width')]

    def clone(self):
        """ Return a identical copy of self """
        return deepcopy(self)

    def run(self, x=0.0):
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q * math.cos(phi)],
                                     [q * math.sin(phi)])[0]
        else:
            return self.calculate_Iq([float(x)])[0]

    def runXY(self, x=0.0):
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([float(x[0])], [float(x[1])])[0]
        else:
            return self.calculate_Iq([float(x)])[0]

    def evalDistribution(self, qdist):
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            partype = self._model_info['partype']
            if not partype['orientation'] and not partype['magnetic']:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calculate_Iq(self, *args):
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.

        The per-call kernel and its q input are released even if the
        evaluation raises.
        """
        if self._kernel is None:
            self._kernel = core.build_model(self._model_info)
        q_vectors = [np.asarray(q) for q in args]
        fn = self._kernel(q_vectors)
        try:
            pars = [self.params[v] for v in fn.fixed_pars]
            pd_pars = [self._get_weights(p) for p in fn.pd_pars]
            result = fn(pars, pd_pars, self.cutoff)
        finally:
            fn.q_input.release()
            fn.release()
        return result

    def calculate_ER(self):
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius; 1.0 when the model
            defines no 'ER' function
        """
        ER = self._model_info.get('ER', None)
        if ER is None:
            return 1.0
        else:
            values, weights = self._dispersion_mesh()
            fv = ER(*values)
            return np.sum(weights * fv) / np.sum(weights)

    def calculate_VR(self):
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio; 1.0 when the model defines no
            'VR' function
        """
        VR = self._model_info.get('VR', None)
        if VR is None:
            return 1.0
        else:
            values, weights = self._dispersion_mesh()
            whole, part = VR(*values)
            return np.sum(weights * part) / np.sum(weights * whole)

    def set_dispersion(self, parameter, dispersion):
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string], case-insensitive
        :param dispersion: dispersion object of type Dispersion
        :raises ValueError: if *parameter* is not a model parameter
        """
        # Resolve to the canonical spelling so that _get_weights, which
        # indexes self.dispersion by the exact parameter name, sees it.
        if parameter in self.params:
            key = parameter
        else:
            canonical = dict((s.lower(), s) for s in self.params.keys())
            key = canonical.get(parameter.lower())
        if key is None:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

        # TODO: Store the disperser object directly in the model.
        # The current method of creating one on the fly whenever it is
        # needed is kind of funky.
        # Note: can't seem to get disperser parameters from sasview
        # (1) Could create a sasview model that has not yet # been
        # converted, assign the disperser to one of its polydisperse
        # parameters, then retrieve the disperser parameters from the
        # sasview model.  (2) Could write a disperser parameter retriever
        # in sasview.  (3) Could modify sasview to use sasmodels.weights
        # dispersers.
        # For now, rely on the fact that the sasview only ever uses
        # new dispersers in the set_dispersion call and create a new
        # one instead of trying to assign parameters.
        from . import weights
        disperser = weights.dispersers[dispersion.__class__.__name__]
        dispersion = weights.models[disperser]()
        self.dispersion[key] = dispersion.get_pars()

    def _dispersion_mesh(self):
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = self._model_info['partype']['volume']
        return core.dispersion_mesh([self._get_weights(p) for p in pars])

    def _get_weights(self, par):
        """
        Return dispersion weights

        :param par: parameter name
        :return: (values, weights) with the weights normalised to sum to 1
        """
        from . import weights

        relative = self._model_info['partype']['pd-rel']
        limits = self._model_info['limits']
        dis = self.dispersion[par]
        value, weight = weights.get_weights(
            dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
            self.params[par], limits[par], par in relative)
        return value, weight / np.sum(weight)
403
Note: See TracBrowser for help on using the repository browser.