source: sasmodels/sasmodels/sasview_model.py @ 1e2a1ba

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 1e2a1ba was 1e2a1ba, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

Merge remote-tracking branch 'origin/master' into polydisp

Conflicts:

sasmodels/core.py
sasmodels/custom/__init__.py
sasmodels/direct_model.py
sasmodels/generate.py
sasmodels/kernelpy.py
sasmodels/sasview_model.py

  • Property mode set to 100644
File size: 14.7 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import make_class
8    from sasmodels.models import cylinder
9    CylinderModel = make_class(cylinder, dtype='single')
10
11The model parameters for sasmodels are different from those in sasview.
12When reloading previously saved models, the parameters should be converted
13using :func:`sasmodels.convert.convert`.
14"""
15from __future__ import print_function
16
17import math
18from copy import deepcopy
19import collections
20import traceback
21import logging
22
23import numpy as np
24
25from . import core
26from . import custom
27from . import generate
28from . import weights
29
def load_standard_models():
    """
    Load and return the list of predefined models.

    Each name reported by :func:`core.list_models` is turned into a model
    class.  If building a model raises an :class:`Exception`, its traceback
    is logged and that model is left out of the returned list.
    ``KeyboardInterrupt`` and ``SystemExit`` are not caught, so the user can
    still abort a slow load.
    """
    models = []
    for name in core.list_models():
        try:
            models.append(_make_standard_model(name))
        except Exception:
            # Best effort: one broken model must not prevent the others
            # from loading.
            logging.error(traceback.format_exc())
    return models
44
45
def load_custom_model(path):
    """
    Build a sasview model class from the custom kernel module found at *path*.

    The module is loaded with :func:`custom.load_custom_kernel_module`.  It is
    converted to a model-info structure and then wrapped as a SasviewModel
    subclass.
    """
    info = generate.make_model_info(custom.load_custom_kernel_module(path))
    return _make_model_from_info(info)
53
54
def _make_standard_model(name):
    """
    Return a sasview model class for the kernel module identified by *name*.

    *name* is resolved by :func:`generate.load_kernel_module`; the resulting
    class can be instantiated and used directly as a sasview model.
    """
    module = generate.load_kernel_module(name)
    return _make_model_from_info(generate.make_model_info(module))
66
67
def _make_model_from_info(model_info):
    """
    Create a SasviewModel subclass wrapping *model_info*.

    The new class is named ``model_info['name']`` and stores *model_info* in
    its ``_model_info`` class attribute.
    """
    def __init__(self, multfactor=1):
        # *multfactor* is accepted but currently ignored.
        SasviewModel.__init__(self)

    namespace = {'__init__': __init__, '_model_info': model_info}
    return type(model_info['name'], (SasviewModel,), namespace)
77
78
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete subclasses are produced by :func:`_make_model_from_info`, which
    sets the class attribute *_model_info*.  Instances read the parameter
    definitions, defaults and limits from it.
    """
    _model_info = {}

    def __init__(self):
        # Compiled kernel model; built lazily by calculate_Iq.
        self._model = None
        model_info = self._model_info
        parameters = model_info['parameters']

        self.name = model_info['name']
        self.description = model_info['description']
        self.category = None
        #self.is_multifunc = False
        # The first control parameter (if any) provides the multiplicity
        # info: [upper limit, name, choices, first profile axis].
        for p in parameters.kernel_parameters:
            if p.is_control:
                profile_axes = model_info['profile_axes']
                self.multiplicity_info = [
                    p.limits[1], p.name, p.choices, profile_axes[0]
                    ]
                break
        else:
            self.multiplicity_info = []

        ## interpret the parameters
        ## TODO: reorganize parameter handling
        self.details = dict()
        self.params = collections.OrderedDict()
        self.dispersion = dict()

        self.orientation_params = []
        self.magnetic_params = []
        self.fixed = []
        for p in parameters.user_parameters():
            self.params[p.name] = p.default
            self.details[p.name] = [p.units] + p.limits
            if p.polydisperse:
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }
            if p.type == 'orientation':
                self.orientation_params.append(p.name)
                self.orientation_params.append(p.name+".width")
                self.fixed.append(p.name+".width")
            if p.type == 'magnetic':
                self.orientation_params.append(p.name)
                self.magnetic_params.append(p.name)
                self.fixed.append(p.name+".width")

        self.non_fittable = []

        ## independent parameter name and unit [string]
        self.input_name = model_info.get("input_name", "Q")
        self.input_unit = model_info.get("input_unit", "A^{-1}")
        self.output_name = model_info.get("output_name", "Intensity")
        self.output_unit = model_info.get("output_unit", "cm^{-1}")

        ## _persistency_dict is used by sas.perspectives.fitting.basepage
        ## to store dispersity reference.
        ## TODO: _persistency_dict to persistency_dict throughout sasview
        self._persistency_dict = {}

        ## New fields introduced for opencl rewrite
        self.cutoff = 1e-5

    def __getstate__(self):
        """
        Return the picklable state of the model.

        The compiled kernel model (*_model*) is excluded.  It is rebuilt
        lazily by :meth:`calculate_Iq` after the state is restored.
        """
        state = self.__dict__.copy()
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __setstate__(self, state):
        """
        Restore the model from *state*, leaving the kernel model unbuilt.
        """
        self.__dict__ = state
        self._model = None

    # Aliases for the previous (misspelled) method names.  pickle and copy
    # only call __getstate__/__setstate__; these aliases keep any direct
    # callers of the old names working.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check
        :return: True if the lower-cased name is listed in *self.fixed*
        """
        return par_name.lower() in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()

    # pylint: disable=no-self-use
    def getProfile(self):
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        return None, None

    def setParam(self, name, value):
        """
        Set the value of a model parameter

        Names are matched case-insensitively.  A dotted name such as
        ``radius.width`` addresses a dispersion setting.

        :param name: name of the parameter
        :param value: value of the parameter
        :raises ValueError: if the parameter does not exist
        """
        # Look for dispersion parameters
        toks = name.split('.')
        if len(toks) == 2:
            for item in self.dispersion.keys():
                if item.lower() == toks[0].lower():
                    for par in self.dispersion[item]:
                        if par.lower() == toks[1].lower():
                            self.dispersion[item][par] = value
                            return
        else:
            # Look for standard parameter
            for item in self.params.keys():
                if item.lower() == name.lower():
                    self.params[item] = value
                    return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        """
        Get the value of a model parameter

        Names are matched case-insensitively.  A dotted name such as
        ``radius.width`` addresses a dispersion setting.

        :param name: name of the parameter
        :return: the current value of the parameter
        :raises ValueError: if the parameter does not exist
        """
        # Look for dispersion parameters
        toks = name.split('.')
        if len(toks) == 2:
            for item in self.dispersion.keys():
                if item.lower() == toks[0].lower():
                    for par in self.dispersion[item]:
                        if par.lower() == toks[1].lower():
                            return self.dispersion[item][par]
        else:
            # Look for standard parameter
            for item in self.params.keys():
                if item.lower() == name.lower():
                    return self.params[item]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        """
        Return a list of all available parameters for the model

        The list contains the standard parameter names, followed by the
        polydispersity parameter names from :meth:`getDispParamList`.
        """
        # list() is required: on Python 3, dict.keys() returns a view
        # that has no extend() method.
        param_list = list(self.params.keys())
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        """
        Return a list of polydispersity parameters for the model
        """
        # TODO: fix test so that parameter order doesn't matter
        ret = ['%s.%s' % (p.name.lower(), ext)
               for p in self._model_info['parameters'].user_parameters()
               for ext in ('npts', 'nsigmas', 'width')
               if p.polydisperse]
        #print(ret)
        return ret

    def clone(self):
        """
        Return an identical copy of self.

        The copy goes through :meth:`__getstate__`/:meth:`__setstate__`, so
        it starts without a compiled kernel model.
        """
        return deepcopy(self)

    def run(self, x=0.0):
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q * math.cos(phi)],
                                     [q * math.sin(phi)])[0]
        else:
            return self.calculate_Iq([float(x)])[0]

    def runXY(self, x=0.0):
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([float(x[0])], [float(x[1])])[0]
        else:
            return self.calculate_Iq([float(x)])[0]

    def evalDistribution(self, qdist):
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info['parameters'].has_2d:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calculate_Iq(self, *args):
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.

        The kernel and its q input are released after every call, including
        when the evaluation raises.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info, platform='dll')
        q_vectors = [np.asarray(q) for q in args]
        kernel = self._model.make_kernel(q_vectors)
        try:
            pairs = [self._get_weights(p)
                     for p in self._model_info['parameters'].call_parameters]
            call_details, pd_weights, pd_values = core.build_details(
                kernel, pairs)
            result = kernel(call_details, pd_weights, pd_values,
                            cutoff=self.cutoff)
        finally:
            # Previously these calls sat after the return statement and never
            # ran, so every evaluation leaked its kernel resources.
            kernel.q_input.release()
            kernel.release()
        return result

    def calculate_ER(self):
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius, averaged over the
            dispersion mesh; 1.0 if the model defines no ER function
        """
        ER = self._model_info.get('ER', None)
        if ER is None:
            return 1.0
        else:
            values, weights = self._dispersion_mesh()
            fv = ER(*values)
            #print(values[0].shape, weights.shape, fv.shape)
            return np.sum(weights * fv) / np.sum(weights)

    def calculate_VR(self):
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio, i.e. the weighted sum of the
            *part* volumes over the weighted sum of the *whole* volumes;
            1.0 if the model defines no VR function
        """
        VR = self._model_info.get('VR', None)
        if VR is None:
            return 1.0
        else:
            values, weights = self._dispersion_mesh()
            whole, part = VR(*values)
            return np.sum(weights * part) / np.sum(weights * whole)

    def set_dispersion(self, parameter, dispersion):
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion
        :raises ValueError: if *parameter* is not a model parameter
        """
        if parameter.lower() in (s.lower() for s in self.params.keys()):
            # TODO: Store the disperser object directly in the model.
            # The current method of creating one on the fly whenever it is
            # needed is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet # been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.  (2) Could write a disperser parameter retriever
            # in sasview.  (3) Could modify sasview to use sasmodels.weights
            # dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            disperser = weights.dispersers[dispersion.__class__.__name__]
            dispersion = weights.models[disperser]()
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

    def _dispersion_mesh(self):
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        # NOTE(review): relies on model_info['partype']['volume'] holding
        # parameter objects accepted by _get_weights; confirm this key still
        # exists in the current model_info layout.
        pars = self._model_info['partype']['volume']
        return core.dispersion_mesh([self._get_weights(p) for p in pars])

    def _get_weights(self, par):
        """
        Return dispersion weights for parameter

        :return: (values, weights). For a polydisperse parameter, the weights
            are normalized to sum to 1. Otherwise the result is
            ([current value], []).
        """
        if par.polydisperse:
            dis = self.dispersion[par.name]
            value, weight = weights.get_weights(
                dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                self.params[par.name], par.limits, par.relative_pd)
            return value, weight / np.sum(weight)
        else:
            return [self.params[par.name]], []
439
def test_model():
    """
    Check that the cylinder model can be built and evaluated.

    Returns the result of evaluating a cylinder instance at qx = qy = 0.1.
    """
    cylinder = _make_standard_model('cylinder')()
    return cylinder.evalDistribution([0.1, 0.1])
447
448
def test_model_list():
    """
    Check that every model reported by core.list_models builds as a sasview
    model.

    A failure is annotated with the offending model name and then re-raised.
    """
    from .exception import annotate_exception
    for model_name in core.list_models():
        try:
            _make_standard_model(model_name)
        except:
            annotate_exception("when loading " + model_name)
            raise
460
if __name__ == "__main__":
    # Smoke test: evaluate the cylinder model at a single (qx, qy) point.
    intensity = test_model()
    print("cylinder(0.1,0.1)=%g" % intensity)
Note: See TracBrowser for help on using the repository browser.