source: sasmodels/sasmodels/sasview_model.py @ 4d76711

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 4d76711 was 4d76711, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

adjust interface to sasview

  • Property mode set to 100644
File size: 14.1 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import make_class
8    from sasmodels.models import cylinder
9    CylinderModel = make_class(cylinder, dtype='single')
10
11The model parameters for sasmodels are different from those in sasview.
12When reloading previously saved models, the parameters should be converted
13using :func:`sasmodels.convert.convert`.
14"""
15from __future__ import print_function
16
17import math
18from copy import deepcopy
19import collections
20import traceback
21import logging
22
23import numpy as np
24
25from . import core
26from . import custom
27from . import generate
28
def load_standard_models():
    """
    Load and return the list of predefined models.

    Each name reported by ``core.list_models()`` is built with
    :func:`_make_standard_model`.  If there is an error loading a model, the
    traceback is logged and that model is left out of the result; the
    remaining models are still loaded.

    :return: list of sasview model classes that were built successfully.
    """
    models = []
    for name in core.list_models():
        try:
            models.append(_make_standard_model(name))
        except Exception:  # pylint: disable=broad-except
            # Best effort: one broken model must not stop the others from
            # loading.  Catch Exception (not a bare except) so that
            # KeyboardInterrupt / SystemExit still propagate.
            logging.error(traceback.format_exc())
    return models
43
44
def load_custom_model(path):
    """
    Build a sasview model class from the custom model file found at *path*.

    The kernel module is loaded with ``custom.load_custom_kernel_module``,
    converted to model info by ``generate.make_model_info`` and then wrapped
    by :func:`_make_model_from_info`, the same way a standard model is.
    """
    module = custom.load_custom_kernel_module(path)
    return _make_model_from_info(generate.make_model_info(module))
52
53
def _make_standard_model(name):
    """
    Return a class usable directly as a sasview model for the model *name*.

    The kernel module is located with ``generate.load_kernel_module``.
    NOTE(review): whether *name* may also be a path to a custom model depends
    on that loader, which is not visible here -- confirm before relying on it.
    """
    info = generate.make_model_info(generate.load_kernel_module(name))
    return _make_model_from_info(info)
65
66
def _make_model_from_info(model_info):
    """
    Create a SasviewModel subclass bound to *model_info*.

    The new class is named ``model_info['name']`` and stores *model_info* as
    its ``_model_info`` class attribute.  Its constructor accepts an optional
    *multfactor* argument that is currently ignored.
    """
    def __init__(self, multfactor=1):
        # NOTE(review): multfactor is accepted but unused; presumably kept so
        # callers that pass a multiplicity factor still work -- confirm.
        SasviewModel.__init__(self)

    return type(model_info['name'], (SasviewModel,),
                {'__init__': __init__, '_model_info': model_info})
76
77
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete subclasses are created by :func:`_make_model_from_info`, which
    sets the ``_model_info`` class attribute.  The compiled model
    (``self._model``) is built lazily on the first :meth:`calculate_Iq` call
    and is excluded from pickled or copied state.
    """
    _model_info = {}

    def __init__(self):
        self._model = None
        model_info = self._model_info

        self.name = model_info['name']
        self.description = model_info['description']
        self.category = None
        self.multiplicity_info = None
        self.is_multifunc = False

        ## interpret the parameters
        ## TODO: reorganize parameter handling
        self.details = dict()  # name -> [units] + limits
        self.params = collections.OrderedDict()  # name -> current value
        self.dispersion = dict()  # pd-2d name -> dispersion settings dict
        partype = model_info['partype']

        for p in model_info['parameters']:
            self.params[p.name] = p.default
            self.details[p.name] = [p.units] + p.limits

        # Every polydisperse (pd-2d) parameter starts with a zero-width
        # gaussian distribution.
        for name in partype['pd-2d']:
            self.dispersion[name] = {
                'width': 0,
                'npts': 35,
                'nsigmas': 3,
                'type': 'gaussian',
            }

        self.orientation_params = (
            partype['orientation']
            + [n + '.width' for n in partype['orientation']]
            + partype['magnetic'])
        self.magnetic_params = partype['magnetic']
        self.fixed = [n + '.width' for n in partype['pd-2d']]
        self.non_fittable = []

        ## independent parameter name and unit [string]
        self.input_name = model_info.get("input_name", "Q")
        self.input_unit = model_info.get("input_unit", "A^{-1}")
        self.output_name = model_info.get("output_name", "Intensity")
        self.output_unit = model_info.get("output_unit", "cm^{-1}")

        ## _persistency_dict is used by sas.perspectives.fitting.basepage
        ## to store dispersity reference.
        ## TODO: _persistency_dict to persistency_dict throughout sasview
        self._persistency_dict = {}

        ## New fields introduced for opencl rewrite
        self.cutoff = 1e-5

    def __getstate__(self):
        """
        Return the instance state for pickling/copying, without the compiled
        model (it is rebuilt on demand by :meth:`calculate_Iq`).
        """
        state = self.__dict__.copy()
        # The compiled model presumably holds OpenCL/ctypes resources that
        # cannot be pickled or deep-copied; drop it.
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __setstate__(self, state):
        """
        Restore state produced by :meth:`__getstate__`; the compiled model is
        rebuilt lazily on the next evaluation.
        """
        self.__dict__ = state
        self._model = None

    # Backward-compatible aliases for the old misspelled names, which the
    # pickle/copy protocols never called.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        """
        Check if a given parameter is fittable or not

        Returns True when *par_name* matches, case-insensitively, one of the
        names in ``self.fixed`` (the ``<pd-2d name>.width`` entries built in
        ``__init__``).

        :param par_name: the parameter name to check
        """
        # Compare case-insensitively on both sides; self.fixed keeps the
        # model's own capitalization.
        wanted = par_name.lower()
        return any(wanted == fixed.lower() for fixed in self.fixed)
        # TODO: in the future use self.params[str(par_name)].is_fittable()

    # pylint: disable=no-self-use
    def getProfile(self):
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        return None, None

    def setParam(self, name, value):
        """
        Set the value of a model parameter

        Names are matched case-insensitively.  A name of the form
        ``<par>.<field>`` addresses a field of the dispersion settings of
        ``<par>``; any other name addresses a plain parameter.

        :param name: name of the parameter
        :param value: value of the parameter
        :raises ValueError: if no matching parameter exists
        """
        # Look for dispersion parameters
        toks = name.split('.')
        if len(toks) == 2:
            for item in self.dispersion.keys():
                if item.lower() == toks[0].lower():
                    for par in self.dispersion[item]:
                        if par.lower() == toks[1].lower():
                            self.dispersion[item][par] = value
                            return
        else:
            # Look for standard parameter
            for item in self.params.keys():
                if item.lower() == name.lower():
                    self.params[item] = value
                    return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        """
        Get the value of a model parameter

        Uses the same case-insensitive name lookup as :meth:`setParam`.

        :param name: name of the parameter
        :raises ValueError: if no matching parameter exists
        """
        # Look for dispersion parameters
        toks = name.split('.')
        if len(toks) == 2:
            for item in self.dispersion.keys():
                if item.lower() == toks[0].lower():
                    for par in self.dispersion[item]:
                        if par.lower() == toks[1].lower():
                            return self.dispersion[item][par]
        else:
            # Look for standard parameter
            for item in self.params.keys():
                if item.lower() == name.lower():
                    return self.params[item]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        """
        Return a list of all available parameters for the model

        Plain parameters come first (in definition order), followed by the
        polydispersity parameters from :meth:`getDispParamList`.
        """
        # list() is required on Python 3, where keys() is a view without
        # an extend() method.
        param_list = list(self.params.keys())
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        """
        Return a list of polydispersity parameters for the model

        Each pd-2d parameter contributes lower-cased ``<name>.npts``,
        ``<name>.nsigmas`` and ``<name>.width`` entries, in that order.
        """
        # TODO: fix test so that parameter order doesn't matter
        return ['%s.%s' % (d.lower(), p)
                for d in self._model_info['partype']['pd-2d']
                for p in ('npts', 'nsigmas', 'width')]

    def clone(self):
        """ Return a identical copy of self (without the compiled model) """
        return deepcopy(self)

    def run(self, x=0.0):
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q * math.cos(phi)],
                                     [q * math.sin(phi)])[0]
        else:
            return self.calculate_Iq([float(x)])[0]

    def runXY(self, x=0.0):
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([float(x[0])], [float(x[1])])[0]
        else:
            return self.calculate_Iq([float(x)])[0]

    def evalDistribution(self, qdist):
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            partype = self._model_info['partype']
            if not partype['orientation'] and not partype['magnetic']:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calculate_Iq(self, *args):
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        q_vectors = [np.asarray(q) for q in args]
        fn = self._model.make_kernel(q_vectors)
        try:
            pars = [self.params[v] for v in fn.fixed_pars]
            pd_pars = [self._get_weights(p) for p in fn.pd_pars]
            result = fn(pars, pd_pars, self.cutoff)
        finally:
            # Release the per-call buffers even when evaluation fails.
            fn.q_input.release()
            fn.release()
        return result

    def calculate_ER(self):
        """
        Calculate the effective radius for P(q)*S(q)

        Averages the model's ER function over the volume-parameter dispersion
        mesh; returns 1.0 when the model defines no ER.

        :return: the value of the effective radius
        """
        ER = self._model_info.get('ER', None)
        if ER is None:
            return 1.0
        else:
            values, weights = self._dispersion_mesh()
            fv = ER(*values)
            return np.sum(weights * fv) / np.sum(weights)

    def calculate_VR(self):
        """
        Calculate the volf ratio for P(q)*S(q)

        Returns sum(w * part) / sum(w * whole) over the volume-parameter
        dispersion mesh, or 1.0 when the model defines no VR.

        :return: the value of the volf ratio
        """
        VR = self._model_info.get('VR', None)
        if VR is None:
            return 1.0
        else:
            values, weights = self._dispersion_mesh()
            whole, part = VR(*values)
            return np.sum(weights * part) / np.sum(weights * whole)

    def set_dispersion(self, parameter, dispersion):
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion
        :raises ValueError: if *parameter* is not a model parameter
        """
        # Case-insensitive match; prefer an exact spelling if present.
        matches = [s for s in self.params.keys()
                   if s.lower() == parameter.lower()]
        if not matches:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)
        key = parameter if parameter in matches else matches[0]
        # TODO: Store the disperser object directly in the model.
        # The current method of creating one on the fly whenever it is
        # needed is kind of funky.
        # Note: can't seem to get disperser parameters from sasview
        # (1) Could create a sasview model that has not yet # been
        # converted, assign the disperser to one of its polydisperse
        # parameters, then retrieve the disperser parameters from the
        # sasview model.  (2) Could write a disperser parameter retriever
        # in sasview.  (3) Could modify sasview to use sasmodels.weights
        # dispersers.
        # For now, rely on the fact that the sasview only ever uses
        # new dispersers in the set_dispersion call and create a new
        # one instead of trying to assign parameters.
        from . import weights
        disperser = weights.dispersers[dispersion.__class__.__name__]
        dispersion = weights.models[disperser]()
        # Store under the model's own spelling, which is what _get_weights
        # uses to look the settings up.
        self.dispersion[key] = dispersion.get_pars()

    def _dispersion_mesh(self):
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = self._model_info['partype']['volume']
        return core.dispersion_mesh([self._get_weights(p) for p in pars])

    def _get_weights(self, par):
        """
        Return (values, normalized weights) for the dispersion of *par*,
        computed by ``weights.get_weights`` from the current dispersion
        settings, parameter value and limits.
        """
        from . import weights
        relative = self._model_info['partype']['pd-rel']
        limits = self._model_info['limits']
        dis = self.dispersion[par]
        value, weight = weights.get_weights(
            dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
            self.params[par], limits[par], par in relative)
        return value, weight / np.sum(weight)
425
426
def test_model():
    """
    Smoke test: build the cylinder model and evaluate it at qx = qy = 0.1,
    returning whatever evalDistribution produces.
    """
    cylinder_model = _make_standard_model('cylinder')()
    return cylinder_model.evalDistribution([0.1, 0.1])
434
435
def test_model_list():
    """
    Check that every model reported by ``core.list_models()`` can be built
    as a sasview model.  A failure is annotated with the offending model
    name and then re-raised.
    """
    from .exception import annotate_exception
    for model_name in core.list_models():
        try:
            _make_standard_model(model_name)
        except:  # annotate, then re-raise unchanged
            annotate_exception("when loading " + model_name)
            raise
447
if __name__ == "__main__":
    # Run the cylinder smoke test and report the computed intensity.
    intensity = test_model()
    print("cylinder(0.1,0.1)=%g" % intensity)
Note: See TracBrowser for help on using the repository browser.