source: sasmodels/sasmodels/sasview_model.py @ 60f03de

Branches: core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 60f03de was 60f03de, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

still more type hinting

  • Property mode set to 100644
File size: 21.6 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import make_class
8    from sasmodels.models import cylinder
9    CylinderModel = make_class(cylinder, dtype='single')
10
11The model parameters for sasmodels are different from those in sasview.
12When reloading previously saved models, the parameters should be converted
13using :func:`sasmodels.convert.convert`.
14"""
15from __future__ import print_function
16
17import math
18from copy import deepcopy
19import collections
20import traceback
21import logging
22
23import numpy as np  # type: ignore
24
25from . import core
26from . import custom
27from . import generate
28from . import weights
29from . import details
30from . import modelinfo
31
32try:
33    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable
34    from .modelinfo import ModelInfo, Parameter
35    from .kernel import KernelModel
36    MultiplicityInfoType = NamedTuple(
37        'MuliplicityInfo',
38        [("number", int), ("control", str), ("choices", List[str]),
39         ("x_axis_label", str)])
40    SasviewModelType = Callable[[int], "SasviewModel"]
41except ImportError:
42    pass
43
# TODO: separate x_axis_label from multiplicity info
# The profile x-axis label belongs with the profile generating function
#
# Record describing the multiplicity (variant) control of a model:
#   number       -- number of variants (0 when the model has no control)
#   control      -- name of the parameter selecting the variant ("" if none)
#   choices      -- choices attached to the control parameter
#   x_axis_label -- x-axis label used for the model profile ("" if none)
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo', "number control choices x_axis_label")
50
51# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Build a sasview model class for every model listed by the core module.

    A model which fails to load is skipped: its traceback is written to
    the log as an error and the remaining models are still processed.

    Returns the list of model classes that loaded successfully.
    """
    loaded = []
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
        except Exception:
            # Best effort: one broken model must not hide the others.
            logging.error(traceback.format_exc())
        else:
            loaded.append(model_class)
    return loaded
67
68
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Build a sasview model class from the custom model file at *path*.

    The kernel module is loaded from *path*, turned into a model info
    structure and then wrapped as a sasview model class.
    """
    module = custom.load_custom_kernel_module(path)
    return _make_model_from_info(modelinfo.make_model_info(module))
77
78
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build the sasview model class for the model called *name*.

    *name* is handed to :func:`generate.load_kernel_module`; it can be a
    standard model name or a path to a custom model.

    Returns a class that can be used directly as a sasview model.
    """
    module = generate.load_kernel_module(name)
    info = modelinfo.make_model_info(module)
    return _make_model_from_info(info)
91
92
def _make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Create a SasviewModel subclass for the model described by *model_info*.

    The new class is named after *model_info.name* and carries the class
    attributes computed by :func:`_generate_model_attributes`.
    """
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_dict = _generate_model_attributes(model_info)
    class_dict['__init__'] = __init__
    bases = (SasviewModel,)
    return type(model_info.name, bases, class_dict)  # type: SasviewModelType
104
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Build the class dictionary for the sasview wrapper of *model_info*.

    The dictionary holds everything needed to query the model (name,
    category, multiplicity, parameter groups, ...) so that the model class
    can be inspected without creating an instance.

    All the attributes should be immutable to avoid accidents.
    """

    # TODO: allow model to override axis labels input/output name/unit

    # Multiplicity: the first kernel parameter which is either the model's
    # named control or flagged as a control decides the number of variants.
    profile_label = (
        "" if model_info.profile is None else model_info.profile_axes[0])
    multiplicity = MultiplicityInfo(0, "", [], profile_label)
    unfittable = []  # type: List[str]
    for par in model_info.parameters.kernel_parameters:
        if par.name == model_info.control:
            unfittable.append(par.name)
            count = len(par.choices)
        elif par.is_control:
            unfittable.append(par.name)
            count = int(par.limits[1])
        else:
            continue
        multiplicity = MultiplicityInfo(
            count, par.name, par.choices, profile_label)
        break

    # Parameter groups.  Magnetic parameters are listed with the
    # orientation parameters as well as in their own group.
    oriented = []
    magnetic = []
    widths = []
    for par in model_info.parameters.user_parameters():
        if par.type == 'orientation':
            oriented.extend((par.name, par.name+".width"))
            widths.append(par.name+".width")
        elif par.type == 'magnetic':
            oriented.append(par.name)
            magnetic.append(par.name)
            widths.append(par.name+".width")

    # Class dictionary; sequences are frozen into tuples.
    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': multiplicity[0] > 1,
        'multiplicity_info': multiplicity,
        'orientation_params': tuple(oriented),
        'magnetic_params': tuple(magnetic),
        'fixed': tuple(widths),
        'non_fittable': tuple(unfittable),
    }  # type: Dict[str, Any]
167
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete model classes subclass this one and fill in the class
    attributes below with the values computed by the
    *_generate_model_attributes* function.  The declarations here provide
    defaults, typing and documentation.
    """
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: Sequence[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: Sequence[str]
    #: names for which :meth:`is_fittable` returns True
    fixed = None              # type: Sequence[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Initialise the parameter, dispersion and details tables.

        :param multiplicity: variant of the model to build, or None if the
            model has no multiplicity.  When given, the parameters hidden
            for that multiplicity and the control parameter itself are left
            out of the tables.
        """
        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            hidden = self._model_info.get_hidden_parameters(multiplicity)
            hidden |= set([self.multiplicity_info.control])
        else:
            hidden = set()

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = {}
        self.details = {}
        for p in self._model_info.parameters.user_parameters():
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                # Width is bounded by 1 for relative polydispersity and
                # unbounded otherwise.
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __getstate__(self):
        # type: () -> Dict[str, Any]
        """
        Return the instance state used by pickle/copy, without the kernel
        model.
        """
        state = self.__dict__.copy()
        # _model only appears in the instance dictionary once calculate_Iq
        # has built it, so its absence must not be an error.
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.; for now it is left in place.
        return state

    def __setstate__(self, state):
        # type: (Dict[str, Any]) -> None
        """
        Restore the instance from *state*; the kernel model is reset to None
        so that it is rebuilt on the next call to :meth:`calculate_Iq`.
        """
        self.__dict__ = state
        self._model = None

    # The hooks were originally defined under these (misspelled) names, which
    # pickle and copy never call.  Keep them as aliases for existing callers.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check whether the lower-cased *par_name* is listed in :attr:`fixed`.

        :param par_name: the parameter name to check
        """
        return par_name.lower() in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()


    def getProfile(self):
        # type: () -> (np.ndarray, np.ndarray)
        """
        Get SLD profile

        The profile function of the model is called with one argument per
        kernel parameter: the multiplicity for the control parameter, the
        current value for scalar parameters and a list of values for vector
        parameters.  Parameters without a current value are passed as NaN.

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        args = [] # type: List[Union[float, np.ndarray]]
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == self.multiplicity_info.control:
                args.append(float(self.multiplicity))
            elif p.length == 1:
                # np.nan rather than np.NaN: the latter alias was removed
                # in NumPy 2.0.
                args.append(self.params.get(p.id, np.nan))
            else:
                # Vector parameters are stored as name1, name2, ...
                args.append([self.params.get(p.id+str(k), np.nan)
                             for k in range(1, p.length+1)])
        return self._model_info.profile(*args)

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        Names are matched case-insensitively.  A name of the form
        *parameter.attribute* sets a dispersion attribute (e.g. width, npts).

        :param name: name of the parameter
        :param value: value of the parameter
        :raises ValueError: if the model has no such parameter
        """
        # Look for dispersion parameters
        toks = name.split('.')
        if len(toks) == 2:
            for item in self.dispersion.keys():
                if item.lower() == toks[0].lower():
                    for par in self.dispersion[item]:
                        if par.lower() == toks[1].lower():
                            self.dispersion[item][par] = value
                            return
        else:
            # Look for standard parameter
            for item in self.params.keys():
                if item.lower() == name.lower():
                    self.params[item] = value
                    return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        Names are matched case-insensitively.  A name of the form
        *parameter.attribute* reads a dispersion attribute (e.g. width, npts).

        :param name: name of the parameter
        :raises ValueError: if the model has no such parameter
        """
        # Look for dispersion parameters
        toks = name.split('.')
        if len(toks) == 2:
            for item in self.dispersion.keys():
                if item.lower() == toks[0].lower():
                    for par in self.dispersion[item]:
                        if par.lower() == toks[1].lower():
                            return self.dispersion[item][par]
        else:
            # Look for standard parameter
            for item in self.params.keys():
                if item.lower() == name.lower():
                    return self.params[item]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model

        Each polydisperse parameter contributes *name.npts*, *name.nsigmas*
        and *name.width*, with the parameter name in lower case.
        """
        # TODO: fix test so that parameter order doesn't matter
        ret = ['%s.%s' % (p.name.lower(), ext)
               for p in self._model_info.parameters.user_parameters()
               for ext in ('npts', 'nsigmas', 'width')
               if p.polydisperse]
        return ret

    def clone(self):
        # type: () -> "SasviewModel"
        """ Return a identical copy of self """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]


    def runXY(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        :raises TypeError: if *qdist* is neither a list/tuple nor an ndarray
        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info.parameters.has_2d:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        The kernel model is built on first use and kept on the instance.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        kernel = self._model.make_kernel(q_vectors)
        try:
            pairs = [self._get_weights(p)
                     for p in self._model_info.parameters.call_parameters]
            call_details, weight, value = details.build_details(kernel, pairs)
            return kernel(call_details, weight, value, cutoff=self.cutoff)
        finally:
            # Release the kernel even when the evaluation fails.
            kernel.release()

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius: the weighted average of
            the model ER function over the dispersion mesh, or 1.0 if the
            model does not define ER.
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio: the weighted sum of the *part*
            values divided by the weighted sum of the *whole* values returned
            by the model VR function, or 1.0 if the model does not define VR.
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion
        :raises ValueError: if *parameter* does not match (ignoring case) any
            of the model parameters
        """
        if parameter.lower() in (s.lower() for s in self.params.keys()):
            # TODO: Store the disperser object directly in the model.
            # The current method of creating one on the fly whenever it is
            # needed is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet # been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.  (2) Could write a disperser parameter retriever
            # in sasview.  (3) Could modify sasview to use sasmodels.weights
            # dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            disperser = weights.dispersers[dispersion.__class__.__name__]
            dispersion = weights.MODELS[disperser]()
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

    def _dispersion_mesh(self):
        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
        """
        Create a mesh grid of dispersion parameters and weights.

        Only the call parameters of type 'volume' are included.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return details.dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
        """
        Return dispersion weights for parameter

        Returns a (values, weights) pair.  Parameters which are not in
        :attr:`params` yield the multiplicity (for the control parameter) or
        NaN, with no weights; non-polydisperse parameters yield their current
        value with no weights; polydisperse parameters yield the distribution
        values with weights normalised to sum to one.
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return [self.multiplicity], []
            else:
                # np.nan rather than np.NaN: the latter alias was removed
                # in NumPy 2.0.
                return [np.nan], []
        elif par.polydisperse:
            dis = self.dispersion[par.name]
            value, weight = weights.get_weights(
                dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                self.params[par.name], par.limits, par.relative_pd)
            return value, weight / np.sum(weight)
        else:
            return [self.params[par.name]], []
599
def test_model():
    # type: () -> float
    """
    Check that the cylinder model can be wrapped and evaluated.

    Returns the value computed by the model for the q pair (0.1, 0.1).
    """
    model_class = _make_standard_model('cylinder')
    q_pair = [0.1, 0.1]
    return model_class().evalDistribution(q_pair)
608
def test_rpa():
    # type: () -> float
    """
    Check that the rpa model can be wrapped and evaluated with
    multiplicity 3.

    Returns the value computed by the model for the q pair (0.1, 0.1).
    """
    model_class = _make_standard_model('rpa')
    q_pair = [0.1, 0.1]
    return model_class(3).evalDistribution(q_pair)
617
618
def test_model_list():
    # type: () -> None
    """
    Check that every model listed by the core module builds as a sasview
    model.

    On failure the exception is annotated with the name of the model being
    loaded and then re-raised.
    """
    from .exception import annotate_exception
    for model_name in core.list_models():
        try:
            _make_standard_model(model_name)
        except:
            # Bare except is deliberate: the exception is always re-raised.
            annotate_exception("when loading %s" % model_name)
            raise
631
if __name__ == "__main__":
    # Smoke test when run as a script: evaluate the cylinder model once
    # and print the result.
    intensity = test_model()
    print("cylinder(0.1,0.1)=%g" % intensity)
Note: See TracBrowser for help on using the repository browser.