source: sasmodels/sasmodels/sasview_model.py @ 04dc697

core_shell_microgelscostrafo411magnetic_modelrelease_v0.94release_v0.95ticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 04dc697 was 04dc697, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

more type hinting

  • Property mode set to 100644
File size: 21.4 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import make_class
8    from sasmodels.models import cylinder
9    CylinderModel = make_class(cylinder, dtype='single')
10
11The model parameters for sasmodels are different from those in sasview.
12When reloading previously saved models, the parameters should be converted
13using :func:`sasmodels.convert.convert`.
14"""
15from __future__ import print_function
16
17import math
18from copy import deepcopy
19import collections
20import traceback
21import logging
22
23import numpy as np  # type: ignore
24
25from . import core
26from . import custom
27from . import generate
28from . import weights
29from . import details
30from . import modelinfo
31
try:
    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union
    from .modelinfo import ModelInfo, Parameter
    from .kernel import KernelModel
    # Typing-only mirror of MultiplicityInfo below; the typename matches the
    # runtime namedtuple so both report the same name.
    MultiplicityInfoType = NamedTuple(
        'MultiplicityInfo',
        [("number", int), ("control", str), ("choices", List[str]),
         ("x_axis_label", str)])
except ImportError:
    # typing (or the project modules used only for type hints) may be
    # unavailable; the names above are only referenced from type comments.
    pass

# TODO: separate x_axis_label from multiplicity info
# The profile x-axis label belongs with the profile generating function
#: Description of a model's variants: number of variants, name of the
#: control parameter, its choices, and the profile x-axis label.
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo',
    ["number", "control", "choices", "x_axis_label"],
)
49
# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[type]
    """
    Load and return the list of predefined models.

    Every name reported by :func:`core.list_models` is built with
    :func:`_make_standard_model`.  If building a model raises, the error is
    logged together with the model name and traceback, and that model is
    skipped; the remaining models are still loaded.
    """
    models = []  # type: List[type]
    for name in core.list_models():
        try:
            models.append(_make_standard_model(name))
        except Exception:
            # Best effort: one broken model must not prevent the others from
            # loading.  logging.exception logs at ERROR level and appends the
            # active traceback, so the failing model is identifiable.
            logging.exception("Error loading model %r", name)
    return models
66
67
def load_custom_model(path):
    # type: (str) -> type
    """
    Build the sasview model class for the custom model file at *path*.

    The kernel module is loaded with :func:`custom.load_custom_kernel_module`
    and then wrapped the same way as a standard model.
    """
    module = custom.load_custom_kernel_module(path)
    info = modelinfo.make_model_info(module)
    return _make_model_from_info(info)
76
77
def _make_standard_model(name):
    # type: (str) -> type
    """
    Build the sasview model class for the model identified by *name*.

    *name* is handed to :func:`generate.load_kernel_module`.  The original
    notes say this may be a standard model name or a path to a custom model;
    confirm against that function.

    Returns a class that can be used directly as a sasview model.
    """
    module = generate.load_kernel_module(name)
    return _make_model_from_info(modelinfo.make_model_info(module))
90
91
def _make_model_from_info(model_info):
    # type: (ModelInfo) -> type
    """
    Create a SasviewModel subclass wrapping *model_info*.

    The class is named after the model and carries the attributes produced
    by :func:`_generate_model_attributes`, plus an ``__init__`` that forwards
    *multiplicity* to :class:`SasviewModel`.
    """
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs = _generate_model_attributes(model_info)
    class_attrs['__init__'] = __init__
    return type(model_info.name, (SasviewModel,), class_attrs)
103
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Build the class-attribute dictionary for a sasview model wrapper.

    Everything needed to query the model is stored as a class attribute, so
    a model can be inspected without being instantiated.  Sequences are
    stored as tuples so the shared class attributes cannot be mutated by
    accident.
    """
    # TODO: allow model to override axis labels input/output name/unit

    # Multiplicity: the first kernel parameter that is the model's declared
    # control parameter, or is flagged as a control, defines the variants.
    if model_info.profile is not None:
        xlabel = model_info.profile_axes[0]
    else:
        xlabel = ""
    non_fittable = []  # type: List[str]
    variants = MultiplicityInfo(0, "", [], xlabel)
    for par in model_info.parameters.kernel_parameters:
        if par.name == model_info.control:
            # declared control: count is the number of choices
            count = len(par.choices)
        elif par.is_control:
            # flagged control: count is taken from the upper limit
            count = int(par.limits[1])
        else:
            continue
        non_fittable.append(par.name)
        variants = MultiplicityInfo(count, par.name, par.choices, xlabel)
        break

    # Group parameters by kind; every orientation/magnetic parameter also
    # contributes its "<name>.width" entry to `fixed`.
    orientation_params = []  # type: List[str]
    magnetic_params = []  # type: List[str]
    fixed = []  # type: List[str]
    for par in model_info.parameters.user_parameters():
        if par.type == 'orientation':
            width = par.name + ".width"
            orientation_params.extend((par.name, width))
            fixed.append(width)
        elif par.type == 'magnetic':
            # magnetic parameters are also listed as orientation parameters
            orientation_params.append(par.name)
            magnetic_params.append(par.name)
            fixed.append(par.name + ".width")

    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants[0] > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
    }
156    attrs['is_structure_factor'] = model_info.structure_factor
157    attrs['is_form_factor'] = model_info.ER is not None
158    attrs['is_multiplicity_model'] = variants[0] > 1
159    attrs['multiplicity_info'] = variants
160    attrs['orientation_params'] = tuple(orientation_params)
161    attrs['magnetic_params'] = tuple(magnetic_params)
162    attrs['fixed'] = tuple(fixed)
163    attrs['non_fittable'] = tuple(non_fittable)
164
165    return attrs
166
167class SasviewModel(object):
168    """
169    Sasview wrapper for opencl/ctypes model.
170    """
171    # Model parameters for the specific model are set in the class constructor
172    # via the _generate_model_attributes function, which subclasses
173    # SasviewModel.  They are included here for typing and documentation
174    # purposes.
175    _model = None       # type: KernelModel
176    _model_info = None  # type: ModelInfo
177    #: load/save name for the model
178    id = None           # type: str
179    #: display name for the model
180    name = None         # type: str
181    #: short model description
182    description = None  # type: str
183    #: default model category
184    category = None     # type: str
185
186    #: names of the orientation parameters in the order they appear
187    orientation_params = None # type: Sequence[str]
188    #: names of the magnetic parameters in the order they appear
189    magnetic_params = None    # type: Sequence[str]
190    #: names of the fittable parameters
191    fixed = None              # type: Sequence[str]
192    # TODO: the attribute fixed is ill-named
193
194    # Axis labels
195    input_name = "Q"
196    input_unit = "A^{-1}"
197    output_name = "Intensity"
198    output_unit = "cm^{-1}"
199
200    #: default cutoff for polydispersity
201    cutoff = 1e-5
202
203    # Note: Use non-mutable values for class attributes to avoid errors
204    #: parameters that are not fitted
205    non_fittable = ()        # type: Sequence[str]
206
207    #: True if model should appear as a structure factor
208    is_structure_factor = False
209    #: True if model should appear as a form factor
210    is_form_factor = False
211    #: True if model has multiplicity
212    is_multiplicity_model = False
213    #: Mulitplicity information
214    multiplicity_info = None # type: MultiplicityInfoType
215
216    # Per-instance variables
217    #: parameter {name: value} mapping
218    params = None      # type: Dict[str, float]
219    #: values for dispersion width, npts, nsigmas and type
220    dispersion = None  # type: Dict[str, Any]
221    #: units and limits for each parameter
222    details = None     # type: Mapping[str, Tuple[str, float, float]]
223    #: multiplicity value, or None if no multiplicity on the model
224    multiplicity = None     # type: Optional[int]
225    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
226    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]
227
228    def __init__(self, multiplicity=None):
229        # type: (Optional[int]) -> None
230
231        # TODO: _persistency_dict to persistency_dict throughout sasview
232        # TODO: refactor multiplicity to encompass variants
233        # TODO: dispersion should be a class
234        # TODO: refactor multiplicity info
235        # TODO: separate profile view from multiplicity
236        # The button label, x and y axis labels and scale need to be under
237        # the control of the model, not the fit page.  Maximum flexibility,
238        # the fit page would supply the canvas and the profile could plot
239        # how it wants, but this assumes matplotlib.  Next level is that
240        # we provide some sort of data description including title, labels
241        # and lines to plot.
242
243        # Get the list of hidden parameters given the mulitplicity
244        # Don't include multiplicity in the list of parameters
245        self.multiplicity = multiplicity
246        if multiplicity is not None:
247            hidden = self._model_info.get_hidden_parameters(multiplicity)
248            hidden |= set([self.multiplicity_info.control])
249        else:
250            hidden = set()
251
252        self._persistency_dict = {}
253        self.params = collections.OrderedDict()
254        self.dispersion = {}
255        self.details = {}
256        for p in self._model_info.parameters.user_parameters():
257            if p.name in hidden:
258                continue
259            self.params[p.name] = p.default
260            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
261            if p.polydisperse:
262                self.details[p.id+".width"] = [
263                    "", 0.0, 1.0 if p.relative_pd else np.inf
264                ]
265                self.dispersion[p.name] = {
266                    'width': 0,
267                    'npts': 35,
268                    'nsigmas': 3,
269                    'type': 'gaussian',
270                }
271
272    def __get_state__(self):
273        # type: () -> Dict[str, Any]
274        state = self.__dict__.copy()
275        state.pop('_model')
276        # May need to reload model info on set state since it has pointers
277        # to python implementations of Iq, etc.
278        #state.pop('_model_info')
279        return state
280
281    def __set_state__(self, state):
282        # type: (Dict[str, Any]) -> None
283        self.__dict__ = state
284        self._model = None
285
286    def __str__(self):
287        # type: () -> str
288        """
289        :return: string representation
290        """
291        return self.name
292
293    def is_fittable(self, par_name):
294        # type: (str) -> bool
295        """
296        Check if a given parameter is fittable or not
297
298        :param par_name: the parameter name to check
299        """
300        return par_name.lower() in self.fixed
301        #For the future
302        #return self.params[str(par_name)].is_fittable()
303
304
305    def getProfile(self):
306        # type: () -> (np.ndarray, np.ndarray)
307        """
308        Get SLD profile
309
310        : return: (z, beta) where z is a list of depth of the transition points
311                beta is a list of the corresponding SLD values
312        """
313        args = []
314        for p in self._model_info.parameters.kernel_parameters:
315            if p.id == self.multiplicity_info.control:
316                args.append(self.multiplicity)
317            elif p.length == 1:
318                args.append(self.params.get(p.id, np.NaN))
319            else:
320                args.append([self.params.get(p.id+str(k), np.NaN)
321                             for k in range(1,p.length+1)])
322        return self._model_info.profile(*args)
323
324    def setParam(self, name, value):
325        # type: (str, float) -> None
326        """
327        Set the value of a model parameter
328
329        :param name: name of the parameter
330        :param value: value of the parameter
331
332        """
333        # Look for dispersion parameters
334        toks = name.split('.')
335        if len(toks) == 2:
336            for item in self.dispersion.keys():
337                if item.lower() == toks[0].lower():
338                    for par in self.dispersion[item]:
339                        if par.lower() == toks[1].lower():
340                            self.dispersion[item][par] = value
341                            return
342        else:
343            # Look for standard parameter
344            for item in self.params.keys():
345                if item.lower() == name.lower():
346                    self.params[item] = value
347                    return
348
349        raise ValueError("Model does not contain parameter %s" % name)
350
351    def getParam(self, name):
352        # type: (str) -> float
353        """
354        Set the value of a model parameter
355
356        :param name: name of the parameter
357
358        """
359        # Look for dispersion parameters
360        toks = name.split('.')
361        if len(toks) == 2:
362            for item in self.dispersion.keys():
363                if item.lower() == toks[0].lower():
364                    for par in self.dispersion[item]:
365                        if par.lower() == toks[1].lower():
366                            return self.dispersion[item][par]
367        else:
368            # Look for standard parameter
369            for item in self.params.keys():
370                if item.lower() == name.lower():
371                    return self.params[item]
372
373        raise ValueError("Model does not contain parameter %s" % name)
374
375    def getParamList(self):
376        # type: () -> Sequence[str]
377        """
378        Return a list of all available parameters for the model
379        """
380        param_list = list(self.params.keys())
381        # WARNING: Extending the list with the dispersion parameters
382        param_list.extend(self.getDispParamList())
383        return param_list
384
385    def getDispParamList(self):
386        # type: () -> Sequence[str]
387        """
388        Return a list of polydispersity parameters for the model
389        """
390        # TODO: fix test so that parameter order doesn't matter
391        ret = ['%s.%s' % (p.name.lower(), ext)
392               for p in self._model_info.parameters.user_parameters()
393               for ext in ('npts', 'nsigmas', 'width')
394               if p.polydisperse]
395        #print(ret)
396        return ret
397
398    def clone(self):
399        # type: () -> "SasviewModel"
400        """ Return a identical copy of self """
401        return deepcopy(self)
402
403    def run(self, x=0.0):
404        # type: (Union[float, (float, float), List[float]]) -> float
405        """
406        Evaluate the model
407
408        :param x: input q, or [q,phi]
409
410        :return: scattering function P(q)
411
412        **DEPRECATED**: use calculate_Iq instead
413        """
414        if isinstance(x, (list, tuple)):
415            # pylint: disable=unpacking-non-sequence
416            q, phi = x
417            return self.calculate_Iq([q * math.cos(phi)],
418                                     [q * math.sin(phi)])[0]
419        else:
420            return self.calculate_Iq([float(x)])[0]
421
422
423    def runXY(self, x=0.0):
424        # type: (Union[float, (float, float), List[float]]) -> float
425        """
426        Evaluate the model in cartesian coordinates
427
428        :param x: input q, or [qx, qy]
429
430        :return: scattering function P(q)
431
432        **DEPRECATED**: use calculate_Iq instead
433        """
434        if isinstance(x, (list, tuple)):
435            return self.calculate_Iq([float(x[0])], [float(x[1])])[0]
436        else:
437            return self.calculate_Iq([float(x)])[0]
438
439    def evalDistribution(self, qdist):
440        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
441        r"""
442        Evaluate a distribution of q-values.
443
444        :param qdist: array of q or a list of arrays [qx,qy]
445
446        * For 1D, a numpy array is expected as input
447
448        ::
449
450            evalDistribution(q)
451
452          where *q* is a numpy array.
453
454        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input
455
456        ::
457
458              qx = [ qx[0], qx[1], qx[2], ....]
459              qy = [ qy[0], qy[1], qy[2], ....]
460
461        If the model is 1D only, then
462
463        .. math::
464
465            q = \sqrt{q_x^2+q_y^2}
466
467        """
468        if isinstance(qdist, (list, tuple)):
469            # Check whether we have a list of ndarrays [qx,qy]
470            qx, qy = qdist
471            if not self._model_info.parameters.has_2d:
472                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
473            else:
474                return self.calculate_Iq(qx, qy)
475
476        elif isinstance(qdist, np.ndarray):
477            # We have a simple 1D distribution of q-values
478            return self.calculate_Iq(qdist)
479
480        else:
481            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
482                            % type(qdist))
483
484    def calculate_Iq(self, qx, qy=None):
485        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
486        """
487        Calculate Iq for one set of q with the current parameters.
488
489        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.
490
491        This should NOT be used for fitting since it copies the *q* vectors
492        to the card for each evaluation.
493        """
494        if self._model is None:
495            self._model = core.build_model(self._model_info)
496        if qy is not None:
497            q_vectors = [np.asarray(qx), np.asarray(qy)]
498        else:
499            q_vectors = [np.asarray(qx)]
500        kernel = self._model.make_kernel(q_vectors)
501        pairs = [self._get_weights(p)
502                 for p in self._model_info.parameters.call_parameters]
503        call_details, weight, value = details.build_details(kernel, pairs)
504        result = kernel(call_details, weight, value, cutoff=self.cutoff)
505        kernel.release()
506        return result
507
508    def calculate_ER(self):
509        # type: () -> float
510        """
511        Calculate the effective radius for P(q)*S(q)
512
513        :return: the value of the effective radius
514        """
515        if self._model_info.ER is None:
516            return 1.0
517        else:
518            value, weight = self._dispersion_mesh()
519            fv = self._model_info.ER(*value)
520            #print(values[0].shape, weights.shape, fv.shape)
521            return np.sum(weight * fv) / np.sum(weight)
522
523    def calculate_VR(self):
524        # type: () -> float
525        """
526        Calculate the volf ratio for P(q)*S(q)
527
528        :return: the value of the volf ratio
529        """
530        if self._model_info.VR is None:
531            return 1.0
532        else:
533            value, weight = self._dispersion_mesh()
534            whole, part = self._model_info.VR(*value)
535            return np.sum(weight * part) / np.sum(weight * whole)
536
537    def set_dispersion(self, parameter, dispersion):
538        # type: (str, weights.Dispersion) -> Dict[str, Any]
539        """
540        Set the dispersion object for a model parameter
541
542        :param parameter: name of the parameter [string]
543        :param dispersion: dispersion object of type Dispersion
544        """
545        if parameter.lower() in (s.lower() for s in self.params.keys()):
546            # TODO: Store the disperser object directly in the model.
547            # The current method of creating one on the fly whenever it is
548            # needed is kind of funky.
549            # Note: can't seem to get disperser parameters from sasview
550            # (1) Could create a sasview model that has not yet # been
551            # converted, assign the disperser to one of its polydisperse
552            # parameters, then retrieve the disperser parameters from the
553            # sasview model.  (2) Could write a disperser parameter retriever
554            # in sasview.  (3) Could modify sasview to use sasmodels.weights
555            # dispersers.
556            # For now, rely on the fact that the sasview only ever uses
557            # new dispersers in the set_dispersion call and create a new
558            # one instead of trying to assign parameters.
559            from . import weights
560            disperser = weights.dispersers[dispersion.__class__.__name__]
561            dispersion = weights.MODELS[disperser]()
562            self.dispersion[parameter] = dispersion.get_pars()
563        else:
564            raise ValueError("%r is not a dispersity or orientation parameter")
565
566    def _dispersion_mesh(self):
567        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
568        """
569        Create a mesh grid of dispersion parameters and weights.
570
571        Returns [p1,p2,...],w where pj is a vector of values for parameter j
572        and w is a vector containing the products for weights for each
573        parameter set in the vector.
574        """
575        pars = [self._get_weights(p)
576                for p in self._model_info.parameters.call_parameters
577                if p.type == 'volume']
578        return details.dispersion_mesh(self._model_info, pars)
579
580    def _get_weights(self, par):
581        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
582        """
583        Return dispersion weights for parameter
584        """
585        if par.name not in self.params:
586            if par.name == self.multiplicity_info.control:
587                return [self.multiplicity], []
588            else:
589                return [np.NaN], []
590        elif par.polydisperse:
591            dis = self.dispersion[par.name]
592            value, weight = weights.get_weights(
593                dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
594                self.params[par.name], par.limits, par.relative_pd)
595            return value, weight / np.sum(weight)
596        else:
597            return [self.params[par.name]], []
598
599def test_model():
600    # type: () -> float
601    """
602    Test that a sasview model (cylinder) can be run.
603    """
604    Cylinder = _make_standard_model('cylinder')
605    cylinder = Cylinder()
606    return cylinder.evalDistribution([0.1,0.1])
607
608def test_rpa():
609    # type: () -> float
610    """
611    Test that a sasview model (cylinder) can be run.
612    """
613    RPA = _make_standard_model('rpa')
614    rpa = RPA(3)
615    return rpa.evalDistribution([0.1,0.1])
616
617
618def test_model_list():
619    # type: () -> None
620    """
621    Make sure that all models build as sasview models.
622    """
623    from .exception import annotate_exception
624    for name in core.list_models():
625        try:
626            _make_standard_model(name)
627        except:
628            annotate_exception("when loading "+name)
629            raise
630
631if __name__ == "__main__":
632    print("cylinder(0.1,0.1)=%g"%test_model())
Note: See TracBrowser for help on using the repository browser.