source: sasmodels/sasmodels/sasview_model.py @ a936688

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since a936688 was a936688, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

update sasview wrapper so that model details are class attributes

  • Property mode set to 100644
File size: 17.3 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import make_class
8    from sasmodels.models import cylinder
9    CylinderModel = make_class(cylinder, dtype='single')
10
11The model parameters for sasmodels are different from those in sasview.
12When reloading previously saved models, the parameters should be converted
13using :func:`sasmodels.convert.convert`.
14"""
15from __future__ import print_function
16
17import math
18from copy import deepcopy
19import collections
20import traceback
21import logging
22
23import numpy as np
24
25from . import core
26from . import custom
27from . import generate
28
29try:
30    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional
31    from .kernel import KernelModel
32    MultiplicityInfoType = NamedTuple(
33        'MuliplicityInfo',
34        [("number", int), ("control", str), ("choices", List[str]),
35         ("x_axis_label", str)])
36except ImportError:
37    pass
38
39# TODO: separate x_axis_label from multiplicity info
40# The x-axis label belongs with the profile generating function
41MultiplicityInfo = collections.namedtuple(
42    'MultiplicityInfo',
43    ["number", "control", "choices", "x_axis_label"],
44)
45
def load_standard_models():
    """
    Load and return the list of predefined models.

    Each name reported by ``core.list_models()`` is turned into a sasview
    model class with :func:`_make_standard_model`.  If there is an error
    loading a model, then a traceback is logged and that model is skipped;
    the remaining models are still loaded.

    :return: list of the model classes that loaded successfully
    """
    models = []
    for name in core.list_models():
        try:
            models.append(_make_standard_model(name))
        except Exception:
            # Catch ordinary errors only (not a bare except) so that
            # KeyboardInterrupt and SystemExit still stop the program, while
            # a single broken model does not prevent the others from loading.
            logging.error(traceback.format_exc())
    return models
60
61
def load_custom_model(path):
    """
    Build a sasview model class from the custom model found at *path*.

    :param path: location of the custom model; handed to
        ``custom.load_custom_kernel_module``
    :return: the generated model class
    """
    return _make_model_from_info(
        generate.make_model_info(custom.load_custom_kernel_module(path)))
69
70
def _make_standard_model(name):
    """
    Return a class usable directly as a sasview model for model *name*.

    *name* can be a standard model name or a path to a custom model; it is
    resolved by ``generate.load_kernel_module``.
    """
    return _make_model_from_info(
        generate.make_model_info(generate.load_kernel_module(name)))
82
83
def _make_model_from_info(model_info):
    """
    Create a SasView model wrapper class for *model_info*.

    The result is a new subclass of :class:`SasviewModel`, named after
    ``model_info['name']``, whose class attributes are produced by
    :func:`_generate_model_attributes`.  Note that *model_info* is modified
    in place: its ``'variant_info'`` entry is set to None.
    """
    # temporary hack for older sasview
    model_info['variant_info'] = None

    def __init__(self, multiplicity=1):
        # Constructor for the generated class: multiplicity defaults to 1
        # and all setup is delegated to the base class.
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs = _generate_model_attributes(model_info)
    class_attrs['__init__'] = __init__
    return type(model_info['name'], (SasviewModel,), class_attrs)
95
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Generate the class attributes for the model.

    This should include all the information necessary to query the model
    details so that you do not need to instantiate a model to query it.

    All the attributes should be immutable to avoid accidents.

    :param model_info: model description.  The keys read here are 'name',
        'description', 'category', 'structure_factor', 'ER' and 'partype';
        ``partype`` supplies the 'orientation', 'magnetic' and 'pd-2d'
        lists of parameter names.
    :return: mapping from class attribute name to value
    """
    attrs = {}  # type: Dict[str, Any]
    attrs['_model_info'] = model_info
    attrs['name'] = model_info['name']
    attrs['description'] = model_info['description']
    attrs['category'] = model_info['category']

    # TODO: allow model to override axis labels input/output name/unit

    # No multiplicity information is extracted from model_info yet, so every
    # model gets a placeholder with zero variants.  The choices are stored
    # as a tuple so that this shared class attribute is immutable.
    variants = MultiplicityInfo(0, "", (), "")
    attrs['is_structure_factor'] = model_info['structure_factor']
    attrs['is_form_factor'] = model_info['ER'] is not None
    attrs['is_multiplicity_model'] = variants.number > 1
    attrs['multiplicity_info'] = variants

    partype = model_info['partype']
    # Orientation angles, their dispersion widths, and the magnetic
    # parameters; SasviewModel.evalDistribution uses the presence of
    # orientation/magnetic parameters to decide on a true 2D evaluation.
    orientation_params = (
        partype['orientation']
        + [n + '.width' for n in partype['orientation']]
        + partype['magnetic'])
    magnetic_params = partype['magnetic']
    # Dispersion widths of the polydisperse parameters (see the "fixed"
    # attribute on SasviewModel, which is used by is_fittable).
    fixed = [n + '.width' for n in partype['pd-2d']]

    attrs['orientation_params'] = tuple(orientation_params)
    attrs['magnetic_params'] = tuple(magnetic_params)
    attrs['fixed'] = tuple(fixed)

    # No parameters are currently marked as non-fittable.
    attrs['non_fittable'] = ()

    return attrs
137
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete model classes are created by subclassing SasviewModel and
    filling in the class attributes below from the model info.
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: Sequence[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: Sequence[str]
    #: names of the fittable parameters
    fixed = None              # type: Sequence[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter, stored as [units] + limits
    details = None     # type: Dict[str, List[Any]]
    #: multiplicity used, or None if no multiplicity controls
    multiplicity = None     # type: Optional[int]

    def __init__(self, multiplicity):
        # type: (int) -> None
        """
        Initialize per-instance parameter state from ``_model_info``.

        Parameter values start at their defaults, ``details`` records
        ``[units] + limits`` for each parameter, and every polydisperse
        ('pd-2d') parameter gets a zero-width gaussian dispersion.  The
        compiled kernel model is built lazily by :meth:`calculate_Iq`.

        :param multiplicity: multiplicity for the model, stored as given
        """
        self._model = None

        ## _persistency_dict is used by sas.perspectives.fitting.basepage
        ## to store dispersity reference.
        self._persistency_dict = {}

        self.multiplicity = multiplicity

        self.params = collections.OrderedDict()
        self.dispersion = {}
        self.details = {}

        for p in self._model_info['parameters']:
            self.params[p.name] = p.default
            self.details[p.name] = [p.units] + p.limits

        for name in self._model_info['partype']['pd-2d']:
            self.dispersion[name] = {
                'width': 0,
                'npts': 35,
                'nsigmas': 3,
                'type': 'gaussian',
            }

    def __getstate__(self):
        """
        Return the instance state for pickling and copying.

        The compiled kernel model is left out; it is rebuilt on demand by
        :meth:`calculate_Iq` once the state has been restored.
        """
        state = self.__dict__.copy()
        state.pop('_model', None)
        # NOTE: _model_info may hold pointers to python implementations of
        # Iq, etc., so it may need to be reloaded when the state is restored.
        return state

    def __setstate__(self, state):
        """
        Restore the state produced by :meth:`__getstate__`.

        The kernel model is reset so that it will be rebuilt when needed.
        """
        self.__dict__ = state
        self._model = None

    # The pickle/copy protocols only look for __getstate__/__setstate__;
    # the earlier misspelled names are kept as aliases for existing callers.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check; it is lower-cased
            before being looked up in :attr:`fixed`
        :return: True if the lower-cased name is listed in :attr:`fixed`
        """
        return par_name.lower() in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()

    # pylint: disable=no-self-use
    def getProfile(self):
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        return None, None

    def setParam(self, name, value):
        """
        Set the value of a model parameter

        Names are matched case-insensitively.  A name of the form
        ``par.field`` (e.g. ``radius.width``) sets a dispersion field.

        :param name: name of the parameter
        :param value: value of the parameter
        :raises ValueError: if the model has no such parameter
        """
        # Look for dispersion parameters
        toks = name.split('.')
        if len(toks) == 2:
            for item in self.dispersion.keys():
                if item.lower() == toks[0].lower():
                    for par in self.dispersion[item]:
                        if par.lower() == toks[1].lower():
                            self.dispersion[item][par] = value
                            return
        else:
            # Look for standard parameter
            for item in self.params.keys():
                if item.lower() == name.lower():
                    self.params[item] = value
                    return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        """
        Get the value of a model parameter

        Names are matched case-insensitively.  A name of the form
        ``par.field`` (e.g. ``radius.width``) reads a dispersion field.

        :param name: name of the parameter
        :return: the parameter value
        :raises ValueError: if the model has no such parameter
        """
        # Look for dispersion parameters
        toks = name.split('.')
        if len(toks) == 2:
            for item in self.dispersion.keys():
                if item.lower() == toks[0].lower():
                    for par in self.dispersion[item]:
                        if par.lower() == toks[1].lower():
                            return self.dispersion[item][par]
        else:
            # Look for standard parameter
            for item in self.params.keys():
                if item.lower() == name.lower():
                    return self.params[item]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        """
        Return a list of all available parameters for the model

        The plain parameters come first, in definition order, followed by
        the polydispersity parameters from :meth:`getDispParamList`.
        """
        # Copy the keys into a real list: on Python 3 dict.keys() is a view
        # with no extend method.
        param_list = list(self.params.keys())
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        """
        Return a list of polydispersity parameters for the model

        Each polydisperse parameter contributes ``name.npts``,
        ``name.nsigmas`` and ``name.width`` (lower-cased name).
        """
        # TODO: fix test so that parameter order doesn't matter
        return ['%s.%s' % (d.lower(), p)
                for d in self._model_info['partype']['pd-2d']
                for p in ('npts', 'nsigmas', 'width')]

    def clone(self):
        """ Return a identical copy of self """
        return deepcopy(self)

    def run(self, x=0.0):
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q * math.cos(phi)],
                                     [q * math.sin(phi)])[0]
        else:
            return self.calculate_Iq([float(x)])[0]

    def runXY(self, x=0.0):
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([float(x[0])], [float(x[1])])[0]
        else:
            return self.calculate_Iq([float(x)])[0]

    def evalDistribution(self, qdist):
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            partype = self._model_info['partype']
            if not partype['orientation'] and not partype['magnetic']:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calculate_Iq(self, *args):
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        q_vectors = [np.asarray(q) for q in args]
        fn = self._model.make_kernel(q_vectors)
        try:
            pars = [self.params[v] for v in fn.fixed_pars]
            pd_pars = [self._get_weights(p) for p in fn.pd_pars]
            result = fn(pars, pd_pars, self.cutoff)
        finally:
            # Release the kernel resources even if parameter lookup or the
            # evaluation itself raises.
            fn.q_input.release()
            fn.release()
        return result

    def calculate_ER(self):
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius, or 1.0 when the model
            defines no ER function
        """
        ER = self._model_info.get('ER', None)
        if ER is None:
            return 1.0
        else:
            values, weights = self._dispersion_mesh()
            fv = ER(*values)
            # Weighted average over the dispersion mesh.
            return np.sum(weights * fv) / np.sum(weights)

    def calculate_VR(self):
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio, or 1.0 when the model defines
            no VR function
        """
        VR = self._model_info.get('VR', None)
        if VR is None:
            return 1.0
        else:
            values, weights = self._dispersion_mesh()
            whole, part = VR(*values)
            return np.sum(weights * part) / np.sum(weights * whole)

    def set_dispersion(self, parameter, dispersion):
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string], matched
            case-insensitively like setParam/getParam
        :param dispersion: dispersion object of type Dispersion
        :raises ValueError: if *parameter* is not a parameter of the model
        """
        matches = [p for p in self.params.keys()
                   if p.lower() == parameter.lower()]
        if not matches:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)
        # TODO: Store the disperser object directly in the model.
        # The current method of creating one on the fly whenever it is
        # needed is kind of funky.
        # Note: can't seem to get disperser parameters from sasview
        # (1) Could create a sasview model that has not yet been
        # converted, assign the disperser to one of its polydisperse
        # parameters, then retrieve the disperser parameters from the
        # sasview model.  (2) Could write a disperser parameter retriever
        # in sasview.  (3) Could modify sasview to use sasmodels.weights
        # dispersers.
        # For now, rely on the fact that the sasview only ever uses
        # new dispersers in the set_dispersion call and create a new
        # one instead of trying to assign parameters.
        from . import weights
        disperser = weights.dispersers[dispersion.__class__.__name__]
        dispersion = weights.models[disperser]()
        # Store under the model's own spelling of the name so that
        # _get_weights (which indexes self.dispersion by that name) finds it.
        self.dispersion[matches[0]] = dispersion.get_pars()

    def _dispersion_mesh(self):
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = self._model_info['partype']['volume']
        return core.dispersion_mesh([self._get_weights(p) for p in pars])

    def _get_weights(self, par):
        """
        Return dispersion weights for parameter

        :return: (values, weights) with the weights normalized to sum to 1
        """
        from . import weights
        relative = self._model_info['partype']['pd-rel']
        limits = self._model_info['limits']
        dis = self.dispersion[par]
        value, weight = weights.get_weights(
            dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
            self.params[par], limits[par], par in relative)
        return value, weight / np.sum(weight)
516
517
def test_model():
    """
    Test that a sasview model (cylinder) can be run.

    :return: result of evaluating a default cylinder model at [0.1, 0.1]
    """
    model = _make_standard_model('cylinder')()
    return model.evalDistribution([0.1, 0.1])
525
526
def test_model_list():
    """
    Make sure that all models build as sasview models.

    A model that fails to build has its exception annotated with the model
    name and re-raised.
    """
    from .exception import annotate_exception
    for model_name in core.list_models():
        try:
            _make_standard_model(model_name)
        except BaseException:
            annotate_exception("when loading " + model_name)
            raise
538
if __name__ == "__main__":
    # Quick smoke check: evaluate the cylinder model once and report it.
    result = test_model()
    print("cylinder(0.1,0.1)=%g" % result)
Note: See TracBrowser for help on using the repository browser.