source: sasmodels/sasmodels/sasview_model.py @ 0d99a6a

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 0d99a6a was 0d99a6a, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

doc strings for modelinfo

  • Property mode set to 100644
File size: 14.9 KB
Line 
1"""
2Sasview model constructor.
3
Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
create a sasview model class to run that kernel.  A class for a custom model
file is built with::

    from sasmodels.sasview_model import load_custom_model
    CylinderModel = load_custom_model('/path/to/cylinder.py')

and :func:`load_standard_models` returns the classes for all predefined models.
10
11The model parameters for sasmodels are different from those in sasview.
12When reloading previously saved models, the parameters should be converted
13using :func:`sasmodels.convert.convert`.
14"""
15from __future__ import print_function
16
17import math
18from copy import deepcopy
19import collections
20import traceback
21import logging
22
23import numpy as np
24
25from . import core
26from . import custom
27from . import generate
28from . import weights
29from . import details
30from . import modelinfo
31
def load_standard_models():
    """
    Build a SasView model class for every predefined model.

    Every name reported by ``core.list_models()`` is passed to
    :func:`_make_standard_model`.  A model that raises while being built is
    skipped, and its traceback is written to the log at error level.

    Returns the list of successfully built classes, in list order.
    """
    loaded = []
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
        except Exception:
            # Best effort: one broken model must not prevent the others
            # from loading.
            logging.error(traceback.format_exc())
            continue
        loaded.append(model_class)
    return loaded
46
47
def load_custom_model(path):
    """
    Build a SasView model class from the custom model file at *path*.

    The file is loaded as a kernel module by
    ``custom.load_custom_kernel_module``, converted to model info, and
    wrapped by :func:`_make_model_from_info`.
    """
    return _make_model_from_info(
        modelinfo.make_model_info(custom.load_custom_kernel_module(path)))
55
56
def _make_standard_model(name):
    """
    Load the sasview model defined by *name*.

    *name* is handed to ``generate.load_kernel_module``; typically the name
    of a predefined model such as ``'cylinder'``.  NOTE(review): whether a
    path to a custom model is also accepted depends on load_kernel_module --
    confirm before relying on it (load_custom_model handles paths).

    Returns a class that can be used directly as a sasview model.
    """
    kernel_module = generate.load_kernel_module(name)
    # The model info must be derived from the module just loaded for *name*;
    # otherwise every requested model would be built from the same input.
    model_info = modelinfo.make_model_info(kernel_module)
    return _make_model_from_info(model_info)
69
70
def _make_model_from_info(model_info):
    """
    Create a SasView model class that wraps *model_info*.

    The returned class is named ``model_info.name``, derives from
    :class:`SasviewModel`, and stores *model_info* in its ``_model_info``
    class attribute.
    """
    def __init__(self, multfactor=1):
        # *multfactor* is accepted but not used by this constructor.
        SasviewModel.__init__(self)

    class_attrs = {
        '__init__': __init__,
        '_model_info': model_info,
    }
    return type(model_info.name, (SasviewModel,), class_attrs)
80
81
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete subclasses are produced by :func:`_make_model_from_info`, which
    sets the ``_model_info`` class attribute.  Parameter values live in
    :attr:`params`, polydispersity settings in :attr:`dispersion`, and the
    compiled model is built lazily by :meth:`calculate_Iq`.
    """
    _model_info = None # type: modelinfo.ModelInfo

    def __init__(self):
        # Compiled model; built on first use by calculate_Iq.
        self._model = None
        model_info = self._model_info
        parameters = model_info.parameters

        self.name = model_info.name
        self.description = model_info.description
        self.category = None
        #self.is_multifunc = False
        # The first control parameter, if any, defines the multiplicity info:
        # [p.limits[1], p.name, p.choices, first profile axis].
        for p in parameters.kernel_parameters:
            if p.is_control:
                profile_axes = model_info.profile_axes
                self.multiplicity_info = [
                    p.limits[1], p.name, p.choices, profile_axes[0]
                    ]
                break
        else:
            self.multiplicity_info = []

        ## interpret the parameters
        ## TODO: reorganize parameter handling
        self.details = dict()
        self.params = collections.OrderedDict()
        self.dispersion = dict()

        self.orientation_params = []
        self.magnetic_params = []
        self.fixed = []
        for p in parameters.user_parameters():
            self.params[p.name] = p.default
            # details[name] is [units] followed by the parameter limits.
            self.details[p.name] = [p.units] + p.limits
            if p.polydisperse:
                # Default polydispersity: zero-width gaussian.
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }
            if p.type == 'orientation':
                self.orientation_params.append(p.name)
                self.orientation_params.append(p.name+".width")
                self.fixed.append(p.name+".width")
            if p.type == 'magnetic':
                self.orientation_params.append(p.name)
                self.magnetic_params.append(p.name)
                self.fixed.append(p.name+".width")

        self.non_fittable = []

        ## independent parameter name and unit [string]
        # No trailing comma: a comma here would make input_name a tuple.
        self.input_name = "Q" #model_info.get("input_name", "Q")
        self.input_unit = "A^{-1}" #model_info.get("input_unit", "A^{-1}")
        self.output_name = "Intensity" #model_info.get("output_name", "Intensity")
        self.output_unit = "cm^{-1}" #model_info.get("output_unit", "cm^{-1}")

        ## _persistency_dict is used by sas.perspectives.fitting.basepage
        ## to store dispersity reference.
        ## TODO: _persistency_dict to persistency_dict throughout sasview
        self._persistency_dict = {}

        ## New fields introduced for opencl rewrite
        self.cutoff = 1e-5

    def __getstate__(self):
        """
        Return the instance state for pickling and (deep)copying.

        The compiled model is left out; it is rebuilt on demand by
        :meth:`calculate_Iq`.
        """
        state = self.__dict__.copy()
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __setstate__(self, state):
        """
        Restore the state produced by :meth:`__getstate__`.

        The compiled model is reset so it will be rebuilt lazily.
        """
        self.__dict__ = state
        self._model = None

    # Earlier misspelled hook names, kept as aliases for any direct callers.
    # pickle and copy only recognise __getstate__/__setstate__.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        """
        Check if a given parameter is fittable or not

        Returns True when the lowercased *par_name* appears in :attr:`fixed`.

        :param par_name: the parameter name to check
        """
        # NOTE(review): entries in self.fixed keep the parameter's original
        # case, so mixed-case names never match here -- confirm intended.
        return par_name.lower() in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()


    # pylint: disable=no-self-use
    def getProfile(self):
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        return None, None

    def setParam(self, name, value):
        """
        Set the value of a model parameter

        Names of the form ``par.field`` address a polydispersity setting in
        :attr:`dispersion`; other names address :attr:`params`.  Matching is
        case-insensitive.

        :param name: name of the parameter
        :param value: value of the parameter
        :raises ValueError: if the parameter does not exist
        """
        # Look for dispersion parameters
        toks = name.split('.')
        if len(toks) == 2:
            for item in self.dispersion.keys():
                if item.lower() == toks[0].lower():
                    for par in self.dispersion[item]:
                        if par.lower() == toks[1].lower():
                            self.dispersion[item][par] = value
                            return
        else:
            # Look for standard parameter
            for item in self.params.keys():
                if item.lower() == name.lower():
                    self.params[item] = value
                    return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        """
        Get the value of a model parameter

        Names of the form ``par.field`` address a polydispersity setting in
        :attr:`dispersion`; other names address :attr:`params`.  Matching is
        case-insensitive.

        :param name: name of the parameter
        :raises ValueError: if the parameter does not exist
        """
        # Look for dispersion parameters
        toks = name.split('.')
        if len(toks) == 2:
            for item in self.dispersion.keys():
                if item.lower() == toks[0].lower():
                    for par in self.dispersion[item]:
                        if par.lower() == toks[1].lower():
                            return self.dispersion[item][par]
        else:
            # Look for standard parameter
            for item in self.params.keys():
                if item.lower() == name.lower():
                    return self.params[item]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        """
        Return a list of all available parameters for the model

        The plain parameter names come first, followed by the entries of
        :meth:`getDispParamList`.
        """
        # list() is required: on Python 3, keys() is a view without extend().
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        """
        Return a list of polydispersity parameters for the model

        Each polydisperse parameter contributes ``name.npts``,
        ``name.nsigmas`` and ``name.width`` (name lowercased).
        """
        # TODO: fix test so that parameter order doesn't matter
        ret = ['%s.%s' % (p.name.lower(), ext)
               for p in self._model_info.parameters.user_parameters()
               for ext in ('npts', 'nsigmas', 'width')
               if p.polydisperse]
        return ret

    def clone(self):
        """ Return a identical copy of self """
        return deepcopy(self)

    def run(self, x=0.0):
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q * math.cos(phi)],
                                     [q * math.sin(phi)])[0]
        else:
            return self.calculate_Iq([float(x)])[0]


    def runXY(self, x=0.0):
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([float(x[0])], [float(x[1])])[0]
        else:
            return self.calculate_Iq([float(x)])[0]

    def evalDistribution(self, qdist):
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info.parameters.has_2d:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calculate_Iq(self, *args):
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        q_vectors = [np.asarray(q) for q in args]
        kernel = self._model.make_kernel(q_vectors)
        try:
            pairs = [self._get_weights(p)
                     for p in self._model_info.parameters.call_parameters]
            call_details, weight, value = details.build_details(kernel, pairs)
            result = kernel(call_details, weight, value, cutoff=self.cutoff)
        finally:
            # Release the q input and kernel even when evaluation fails.
            kernel.q_input.release()
            kernel.release()
        return result

    def calculate_ER(self):
        """
        Calculate the effective radius for P(q)*S(q)

        Returns 1.0 when the model defines no ER function; otherwise the
        weight-averaged ER over the volume-parameter dispersion mesh.

        :return: the value of the effective radius
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        """
        Calculate the volf ratio for P(q)*S(q)

        Returns 1.0 when the model defines no VR function; otherwise the
        weighted sum of *part* divided by the weighted sum of *whole*.

        :return: the value of the volf ratio
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion
        :raises ValueError: if *parameter* is not a model parameter
        """
        if parameter.lower() in (s.lower() for s in self.params.keys()):
            # TODO: Store the disperser object directly in the model.
            # The current method of creating one on the fly whenever it is
            # needed is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet # been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.  (2) Could write a disperser parameter retriever
            # in sasview.  (3) Could modify sasview to use sasmodels.weights
            # dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            disperser = weights.dispersers[dispersion.__class__.__name__]
            dispersion = weights.models[disperser]()
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

    def _dispersion_mesh(self):
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return details.dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        """
        Return dispersion weights for parameter

        Polydisperse parameters return (values, weights) from
        ``weights.get_weights`` with the weights normalized to sum to 1.
        Other parameters return ([current value], []).
        """
        if par.polydisperse:
            dis = self.dispersion[par.name]
            value, weight = weights.get_weights(
                dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                self.params[par.name], par.limits, par.relative_pd)
            return value, weight / np.sum(weight)
        else:
            return [self.params[par.name]], []
442
def test_model():
    """
    Check that the cylinder model builds as a sasview model and runs.

    Evaluates the model at the 2D point (qx, qy) = (0.1, 0.1) and returns
    the result of ``evalDistribution``.
    """
    cylinder = _make_standard_model('cylinder')()
    return cylinder.evalDistribution([0.1, 0.1])
450
451
def test_model_list():
    """
    Make sure that all models build as sasview models.

    If a model fails to build, ``annotate_exception`` is given the note
    ``"when loading <name>"`` and the exception is re-raised.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            # Catch ordinary errors only; KeyboardInterrupt/SystemExit
            # propagate untouched.
            annotate_exception("when loading "+name)
            raise
463
if __name__ == "__main__":
    cylinder_iq = test_model()
    print("cylinder(0.1,0.1)=%g" % cylinder_iq)
Note: See TracBrowser for help on using the repository browser.