source: sasmodels/sasmodels/sasview_model.py @ fa5fd8d

Branches: core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since fa5fd8d was fa5fd8d, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

support number of shells selection in sasview wrapper for onion model

  • Property mode set to 100644
File size: 20.8 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import make_class
8    from sasmodels.models import cylinder
9    CylinderModel = make_class(cylinder, dtype='single')
10
11The model parameters for sasmodels are different from those in sasview.
12When reloading previously saved models, the parameters should be converted
13using :func:`sasmodels.convert.convert`.
14"""
15from __future__ import print_function
16
17import math
18from copy import deepcopy
19import collections
20import traceback
21import logging
22
23import numpy as np  # type: ignore
24
25from . import core
26from . import custom
27from . import generate
28from . import weights
29from . import details
30from . import modelinfo
31
try:
    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional
    from .modelinfo import ModelInfo, Parameter
    from .kernel import KernelModel
    MultiplicityInfoType = NamedTuple(
        'MultiplicityInfo',
        [("number", int), ("control", str), ("choices", List[str]),
         ("x_axis_label", str)])
except ImportError:
    # These names are only referenced from type comments in this module,
    # so it is fine if they cannot be imported at runtime.
    pass

# TODO: separate x_axis_label from multiplicity info
# The x-axis label belongs with the profile generating function

#: Multiplicity description for a model, built by _generate_model_attributes:
#: *number* is the upper limit of the control parameter, *control* is its
#: name, *choices* are its choices and *x_axis_label* is the first profile
#: axis label.  Models without a control parameter use (0, "", [], "").
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo',
    ["number", "control", "choices", "x_axis_label"],
)
49
# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[type]
    """
    Build a sasview model class for every model listed by core.list_models.

    A model whose class cannot be built is skipped.  Its traceback is
    written to the log at error level instead of being raised.
    """
    loaded = []
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
        except Exception:
            logging.error(traceback.format_exc())
        else:
            loaded.append(model_class)
    return loaded
66
67
def load_custom_model(path):
    # type: (str) -> type
    """
    Build a sasview model class from the custom model file at *path*.
    """
    info = modelinfo.make_model_info(custom.load_custom_kernel_module(path))
    return _make_model_from_info(info)
76
77
def _make_standard_model(name):
    # type: (str) -> type
    """
    Build the sasview model class for the model called *name*.

    *name* is passed to generate.load_kernel_module to locate the kernel
    module; standard model names are the usual input.  Custom model files
    are loaded through load_custom_model instead.

    Returns a class that can be used directly as a sasview model.
    """
    module = generate.load_kernel_module(name)
    info = modelinfo.make_model_info(module)
    return _make_model_from_info(info)
90
91
def _make_model_from_info(model_info):
    # type: (ModelInfo) -> type
    """
    Create a SasviewModel subclass that wraps *model_info*.

    The new class is named after the model.  It carries the class
    attributes returned by _generate_model_attributes, plus an __init__
    that forwards *multiplicity* to SasviewModel.__init__.
    """
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs = _generate_model_attributes(model_info)
    class_attrs['__init__'] = __init__
    return type(model_info.name, (SasviewModel,), class_attrs)
103
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Compute the class attributes for the sasview wrapper of *model_info*.

    Everything needed to query the model is stored on the class, so the
    model can be inspected without creating an instance.  This includes the
    name, description, parameter groupings and multiplicity information.
    Parameter-name lists are stored as tuples so the class attributes
    cannot be mutated by accident.
    """
    # TODO: allow model to override axis labels input/output name/unit
    parameters = model_info.parameters

    # The first control parameter (if any) defines the multiplicity; it is
    # reported as non-fittable.
    control = next(
        (p for p in parameters.kernel_parameters if p.is_control), None)
    if control is None:
        non_fittable = []  # type: List[str]
        multiplicity_info = MultiplicityInfo(0, "", [], "")
    else:
        non_fittable = [control.name]
        multiplicity_info = MultiplicityInfo(
            control.limits[1], control.name, control.choices,
            model_info.profile_axes[0],
        )

    orientation_params = []  # type: List[str]
    magnetic_params = []     # type: List[str]
    fixed = []               # type: List[str]
    for p in parameters.user_parameters():
        width_name = p.name + ".width"
        if p.type == 'orientation':
            orientation_params.extend((p.name, width_name))
            fixed.append(width_name)
        elif p.type == 'magnetic':
            orientation_params.append(p.name)
            magnetic_params.append(p.name)
            fixed.append(width_name)

    return {
        '_model_info': model_info,
        'name': model_info.name,
        'description': model_info.description,
        'category': None,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': multiplicity_info.number > 1,
        'multiplicity_info': multiplicity_info,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
    }
162
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Model-specific class attributes are filled in by
    _generate_model_attributes when _make_model_from_info builds a subclass.
    The defaults below are present for typing and documentation purposes.
    """
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: Sequence[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: Sequence[str]
    #: names of the fittable parameters
    fixed = None              # type: Sequence[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    # Note: inside methods the bare name ``details`` is the module-level
    # ``details`` import, not this attribute (use ``self.details`` for this).
    details = None     # type: Mapping[str, Tuple[str, float, float]]
    #: multiplicity used, or None if no multiplicity controls
    multiplicity = None     # type: Optional[int]

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Set up parameter values, limits and dispersion settings from the
        model definition.

        :param multiplicity: value for the model's control parameter.  When
            None, no control value is passed to ``user_parameters``.
        """
        self._model = None

        ## _persistency_dict is used by sas.perspectives.fitting.basepage
        ## to store dispersity reference.
        ## TODO: _persistency_dict to persistency_dict throughout sasview
        self._persistency_dict = {}

        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        self.multiplicity = multiplicity
        self.params = collections.OrderedDict()
        self.dispersion = {}
        self.details = {}
        control = self.multiplicity_info.control
        config = {control: multiplicity} if multiplicity is not None else {}
        for p in self._model_info.parameters.user_parameters(config):
            # Don't include multiplicity in the list of parameters
            if p.name == control:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __getstate__(self):
        # type: () -> Dict[str, Any]
        """
        Return the picklable state of the model.

        The compiled kernel model is dropped; it is rebuilt on demand by
        calculate_Iq.  This is also what keeps clone() (deepcopy) from
        copying the kernel.
        """
        state = self.__dict__.copy()
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __setstate__(self, state):
        # type: (Dict[str, Any]) -> None
        """
        Restore the model from *state*, leaving the kernel to be rebuilt.
        """
        self.__dict__ = state
        self._model = None

    # Backward-compatible aliases for the previously misspelled names,
    # which the pickle and copy protocols never called.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not.

        The name comparison is case-insensitive, like setParam/getParam.

        :param par_name: the parameter name to check
        """
        target = par_name.lower()
        return any(name.lower() == target for name in self.fixed)
        #For the future
        #return self.params[str(par_name)].is_fittable()

    def getProfile(self):
        # type: () -> (np.ndarray, np.ndarray)
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        args = []
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == self.multiplicity_info.control:
                args.append(self.multiplicity)
            elif p.length == 1:
                args.append(self.params.get(p.id, np.nan))
            else:
                # Vector parameters are stored as name1, name2, ..., nameN.
                args.append([self.params.get(p.id+str(k), np.nan)
                             for k in range(1, p.length+1)])
        return self._model_info.profile(*args)

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter (case-insensitive name match).

        :param name: name of the parameter, or "par.field" for a dispersion
            setting such as "radius.width"
        :param value: value of the parameter
        :raises ValueError: if no matching parameter exists
        """
        # Look for dispersion parameters
        toks = name.split('.')
        if len(toks) == 2:
            for item in self.dispersion.keys():
                if item.lower() == toks[0].lower():
                    for par in self.dispersion[item]:
                        if par.lower() == toks[1].lower():
                            self.dispersion[item][par] = value
                            return
        else:
            # Look for standard parameter
            for item in self.params.keys():
                if item.lower() == name.lower():
                    self.params[item] = value
                    return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter (case-insensitive name match).

        :param name: name of the parameter, or "par.field" for a dispersion
            setting such as "radius.width"
        :return: the stored value
        :raises ValueError: if no matching parameter exists
        """
        # Look for dispersion parameters
        toks = name.split('.')
        if len(toks) == 2:
            for item in self.dispersion.keys():
                if item.lower() == toks[0].lower():
                    for par in self.dispersion[item]:
                        if par.lower() == toks[1].lower():
                            return self.dispersion[item][par]
        else:
            # Look for standard parameter
            for item in self.params.keys():
                if item.lower() == name.lower():
                    return self.params[item]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model.

        The list holds the standard parameter names followed by the
        dispersion parameter names from getDispParamList.
        """
        # list() is required on Python 3, where keys() returns a view.
        param_list = list(self.params.keys())
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model.

        Names have the form "par.npts", "par.nsigmas", "par.width", using the
        lower-cased parameter name.
        """
        # TODO: fix test so that parameter order doesn't matter
        ret = ['%s.%s' % (p.name.lower(), ext)
               for p in self._model_info.parameters.user_parameters()
               for ext in ('npts', 'nsigmas', 'width')
               if p.polydisperse]
        return ret

    def clone(self):
        # type: () -> "SasviewModel"
        """ Return a identical copy of self (without the compiled kernel) """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, Tuple[float, float], List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q * math.cos(phi)],
                                     [q * math.sin(phi)])[0]
        else:
            return self.calculate_Iq([float(x)])[0]

    def runXY(self, x=0.0):
        # type: (Union[float, Tuple[float, float], List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([float(x[0])], [float(x[1])])[0]
        else:
            return self.calculate_Iq([float(x)])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info.parameters.has_2d:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        kernel = self._model.make_kernel(q_vectors)
        try:
            pairs = [self._get_weights(p)
                     for p in self._model_info.parameters.call_parameters]
            call_details, weight, value = details.build_details(kernel, pairs)
            return kernel(call_details, weight, value, cutoff=self.cutoff)
        finally:
            # Release the kernel even when evaluation fails.
            kernel.release()

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius (1.0 if the model has no
            ER function)
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio (1.0 if the model has no VR
            function)
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string], matched
            case-insensitively against the model parameters
        :param dispersion: dispersion object of type Dispersion
        :raises ValueError: if *parameter* is not a model parameter
        """
        target = parameter.lower()
        canonical = [name for name in self.params if name.lower() == target]
        if not canonical:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)
        # TODO: Store the disperser object directly in the model.
        # The current method of creating one on the fly whenever it is
        # needed is kind of funky.
        # Note: can't seem to get disperser parameters from sasview
        # (1) Could create a sasview model that has not yet # been
        # converted, assign the disperser to one of its polydisperse
        # parameters, then retrieve the disperser parameters from the
        # sasview model.  (2) Could write a disperser parameter retriever
        # in sasview.  (3) Could modify sasview to use sasmodels.weights
        # dispersers.
        # For now, rely on the fact that the sasview only ever uses
        # new dispersers in the set_dispersion call and create a new
        # one instead of trying to assign parameters.
        disperser = weights.dispersers[dispersion.__class__.__name__]
        dispersion = weights.models[disperser]()
        # Store under the model's own spelling so _get_weights can find it.
        self.dispersion[canonical[0]] = dispersion.get_pars()

    def _dispersion_mesh(self):
        # type: () -> Tuple[List[np.ndarray], np.ndarray]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return details.dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
        """
        Return (values, weights) for parameter *par*.

        Parameters not stored in self.params yield the multiplicity (for the
        control parameter) or NaN, with an empty weight list.  Polydisperse
        parameters yield normalized weights from weights.get_weights.
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return [self.multiplicity], []
            else:
                return [np.nan], []
        elif par.polydisperse:
            dis = self.dispersion[par.name]
            value, weight = weights.get_weights(
                dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                self.params[par.name], par.limits, par.relative_pd)
            return value, weight / np.sum(weight)
        else:
            return [self.params[par.name]], []
593
def test_model():
    # type: () -> float
    """
    Check that the cylinder model can be built and evaluated.

    Returns the result of evalDistribution at [0.1, 0.1].
    """
    cylinder_model = _make_standard_model('cylinder')()
    return cylinder_model.evalDistribution([0.1,0.1])
602
603
def test_model_list():
    # type: () -> None
    """
    Check that every model from core.list_models builds as a sasview model.

    A failure is annotated via annotate_exception with the model name and
    then re-raised.
    """
    from .exception import annotate_exception
    for model_name in core.list_models():
        try:
            _make_standard_model(model_name)
        except BaseException:
            annotate_exception("when loading " + model_name)
            raise
616
if __name__ == "__main__":
    # Quick smoke check: evaluate the cylinder model at qx = qy = 0.1.
    intensity = test_model()
    print("cylinder(0.1,0.1)=%g" % intensity)
Note: See TracBrowser for help on using the repository browser.