source: sasmodels/sasmodels/sasview_model.py @ bde38b5

core_shell_microgels, costrafo411, magnetic_model, release_v0.94, release_v0.95, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since bde38b5 was bde38b5, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

simplify kernel calling

  • Property mode set to 100644
File size: 24.0 KB
RevLine 
[87985ca]1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
[92d38285]7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
[87985ca]9"""
[4d76711]10from __future__ import print_function
[87985ca]11
[ce27e21]12import math
13from copy import deepcopy
[2622b3f]14import collections
[4d76711]15import traceback
16import logging
[9457498]17from os.path import basename, splitext
[ce27e21]18
[7ae2b7f]19import numpy as np  # type: ignore
[ce27e21]20
[aa4946b]21from . import core
[4d76711]22from . import custom
[72a081d]23from . import generate
[fb5914f]24from . import weights
[6d6508e]25from . import modelinfo
[bde38b5]26from .details import make_kernel_args, dispersion_mesh
[ff7119b]27
[fa5fd8d]28try:
[60f03de]29    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable
[fa5fd8d]30    from .modelinfo import ModelInfo, Parameter
31    from .kernel import KernelModel
32    MultiplicityInfoType = NamedTuple(
33        'MuliplicityInfo',
34        [("number", int), ("control", str), ("choices", List[str]),
35         ("x_axis_label", str)])
[60f03de]36    SasviewModelType = Callable[[int], "SasviewModel"]
[fa5fd8d]37except ImportError:
38    pass
39
# When True, load_standard_models() also calls _register_old_models() so the
# new models are placed into sasview under their old sas.models.* names.
SUPPORT_OLD_STYLE_PLUGINS = True
41
def _register_old_models():
    # type: () -> None
    """
    Place the new models into sasview under the old names.

    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
    is available to the plugin modules.  Then, for every entry of the
    conversion table, publish a pseudo-module ``sas.models.<OldName>``
    whose ``<OldName>`` attribute is the new sasview model class.  The
    pseudo-module is both registered in ``sys.modules`` (so it can be
    imported) and bound as the attribute ``<OldName>`` of ``sas.models``.

    Models named in the conversion table that are not available through
    find_model (find_model raises ValueError) are logged and skipped, so
    that one missing model does not prevent the others from being
    registered.
    """
    import sys
    import sas
    import sas.sascalc.fit
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit

    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE
    for new_name, conversion in CONVERSION_TABLE.items():
        old_name = conversion[0]
        try:
            model = find_model(new_name)
        except ValueError:
            # e.g. the model failed to build in load_standard_models and
            # was therefore never added to MODELS.
            logging.warning("cannot register old name %r: model %r is not loaded",
                            old_name, new_name)
            continue
        ConstructedModule = type(old_name, (), {old_name: model})
        old_path = 'sas.models.' + old_name
        # Bind the short name, as a regular submodule import would; a dotted
        # attribute name is unreachable through normal attribute access.
        setattr(sas.models, old_name, ConstructedModule)
        sys.modules[old_path] = ConstructedModule
65
[9457498]66
# TODO: separate x_axis_label from multiplicity info
#: Multiplicity description attached to each model class:
#: (number of variants, name of the control parameter, choice labels,
#: x-axis label taken from the model's profile_axes).
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo', "number control choices x_axis_label")
72
#: Registry of sasview model classes loaded so far, keyed by model name.
MODELS = {}


def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Return the sasview model class called *modelname*.

    A name ending in ``.py`` is treated as the path to a custom model and is
    loaded with load_custom_model; any other name must already be present
    in MODELS.

    :raises ValueError: if *modelname* is not a registered model.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
88
[56b2687]89
# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Load and return the list of predefined models.

    Every model successfully built is also stored in MODELS under its name.
    If there is an error loading a model, then the error (including the
    model name and traceback) is logged and the model is not returned.

    When SUPPORT_OLD_STYLE_PLUGINS is set, the models are afterwards also
    registered under their old sas.models names.
    """
    models = []
    for name in core.list_models():
        try:
            model = _make_standard_model(name)
        except Exception:
            # Keep going so one broken model does not hide the others, and
            # record which model failed alongside the traceback.
            logging.exception("failed to load model %r", name)
            continue
        MODELS[name] = model
        models.append(model)
    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()

    return models
[de97440]110
[4d76711]111
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load a custom model given the model path.

    If the module loaded from *path* exposes a ``Model`` class (an old style
    model), that class is used as-is.  Otherwise the module is treated as a
    model definition and wrapped into a new sasview model class.  In both
    cases the result is registered in MODELS under its name and returned.
    """
    kernel_module = custom.load_custom_kernel_module(path)
    try:
        legacy_class = kernel_module.Model
        # Old style models leave the class-level name empty, so derive it
        # from the file name; creating an instance later overrides it with
        # an instance variable holding the same value.
        if legacy_class.name == "":
            legacy_class.name = splitext(basename(path))[0]
        result = legacy_class
    except AttributeError:
        info = modelinfo.make_model_info(kernel_module)
        result = _make_model_from_info(info)
    MODELS[result.name] = result
    return result
[4d76711]130
[87985ca]131
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build a sasview model class for the model called *name*.

    *name* is passed unchanged to generate.load_kernel_module.  Whether a
    path to a custom model is also accepted depends on that function
    (TODO confirm).

    Returns a class that can be used directly as a sasview model.
    """
    return _make_model_from_info(
        modelinfo.make_model_info(generate.load_kernel_module(name)))
[72a081d]144
145
def _make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Create a SasviewModel subclass for *model_info*.

    The class is named after the model and carries the class attributes
    produced by _generate_model_attributes, plus an ``__init__`` that
    forwards *multiplicity* to SasviewModel.__init__.
    """
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs = _generate_model_attributes(model_info)
    class_attrs['__init__'] = __init__
    return type(model_info.name, (SasviewModel,), class_attrs)
157
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Build the class attribute dictionary for the sasview model of *model_info*.

    The attributes carry everything needed to query the model details
    (name, category, multiplicity, orientation/magnetic parameters, ...)
    without instantiating the model.

    The parameter-name collections are stored as tuples so that the shared
    class attributes cannot be modified by accident.
    """
    # TODO: allow model to override axis labels input/output name/unit

    xlabel = "" if model_info.profile is None else model_info.profile_axes[0]
    non_fittable = []  # type: List[str]

    # Multiplicity: the first kernel parameter named by model_info.control
    # drives the variants and is never fitted.
    control = next((p for p in model_info.parameters.kernel_parameters
                    if p.name == model_info.control), None)
    if control is None:
        variants = MultiplicityInfo(0, "", [], xlabel)
    else:
        non_fittable.append(control.name)
        if control.choices:
            count = len(control.choices)
        else:
            count = int(control.limits[1])
        variants = MultiplicityInfo(count, control.name, control.choices, xlabel)

    # Only a single drop-down list parameter is supported: use the first
    # kernel parameter with choices.  For a vector parameter, each numbered
    # element (id1, id2, ...) is marked non-fittable.
    selector = next((p for p in model_info.parameters.kernel_parameters
                     if p.choices), None)
    if selector is None:
        fun_list = []
    else:
        fun_list = selector.choices
        if selector.length > 1:
            non_fittable.extend(selector.id + str(k)
                                for k in range(1, selector.length + 1))

    # Sort user parameters into orientation / magnetic groups.
    orientation_params = []
    magnetic_params = []
    fixed = []
    for p in model_info.parameters.user_parameters():
        width_name = p.name + ".width"
        if p.type == 'orientation':
            orientation_params.extend([p.name, width_name])
            fixed.append(width_name)
        elif p.type == 'magnetic':
            orientation_params.append(p.name)
            magnetic_params.append(p.name)
            fixed.append(width_name)

    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants[0] > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }
[4d76711]226
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete model classes are subclasses whose class attributes (declared
    below) are filled in from the model info.  Each instance holds its own
    parameter values (*params*), polydispersity settings (*dispersion*) and
    units/limits (*details*).  The underlying kernel model is built lazily
    on the first call to calculate_Iq and is not copied or pickled.
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: Sequence[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: Sequence[str]
    #: names of the fittable parameters
    fixed = None              # type: Sequence[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Initialize parameter values, details and dispersion settings from
        the parameter defaults in the model info.

        :param multiplicity: selected variant for multiplicity models, or
            None.  When given, the parameters hidden for this variant and
            the multiplicity control parameter itself are left out.
        """

        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            hidden = self._model_info.get_hidden_parameters(multiplicity)
            hidden |= set([self.multiplicity_info.control])
        else:
            hidden = set()

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters():
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                # Width upper limit is 1.0 for relative_pd parameters,
                # otherwise unbounded.
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __getstate__(self):
        # type: () -> Dict[str, Any]
        """
        Return the instance state for pickle and copy.deepcopy.

        The built kernel model (``_model``) is excluded; it is rebuilt on
        demand by calculate_Iq.
        """
        state = self.__dict__.copy()
        # _model only appears in the instance dict once calculate_Iq has run.
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __setstate__(self, state):
        # type: (Dict[str, Any]) -> None
        """Restore the instance state; the kernel is rebuilt on next use."""
        self.__dict__ = state
        self._model = None

    # Backward-compatible aliases for the original misspelled hook names,
    # which the pickle/copy protocols never call.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        Returns True when *par_name* is listed in the class attribute *fixed*.

        :param par_name: the parameter name to check
        """
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()

    def getProfile(self):
        # type: () -> (np.ndarray, np.ndarray)
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        args = {} # type: Dict[str, Any]
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == self.multiplicity_info.control:
                value = float(self.multiplicity)
            elif p.length == 1:
                value = self.params.get(p.id, np.nan)
            else:
                # Vector parameters are stored as id1, id2, ... in params.
                value = np.array([self.params.get(p.id+str(k), np.nan)
                                  for k in range(1, p.length+1)])
            args[p.id] = value

        x, y = self._model_info.profile(**args)
        # NOTE(review): the 1e-6 factor presumably converts the profile SLD
        # from 1e-6/Ang^2 units; confirm against the model profile functions.
        return x, 1e-6*y

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        A name of the form ``par.field`` (e.g. ``radius.width``) sets a
        dispersion setting; any other name sets a parameter value.

        :param name: name of the parameter
        :param value: value of the parameter

        :raises ValueError: if the model has no such parameter.
        """
        toks = name.split('.')
        if len(toks) == 2:
            settings = self.dispersion.get(toks[0])
            if settings is not None and toks[1] in settings:
                settings[toks[1]] = value
                return
        elif name in self.params:
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        A name of the form ``par.field`` (e.g. ``radius.width``) returns a
        dispersion setting; any other name returns a parameter value.

        :param name: name of the parameter

        :raises ValueError: if the model has no such parameter.
        """
        toks = name.split('.')
        if len(toks) == 2:
            settings = self.dispersion.get(toks[0])
            if settings is not None and toks[1] in settings:
                return settings[toks[1]]
        elif name in self.params:
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model
        """
        # TODO: fix test so that parameter order doesn't matter
        ret = ['%s.%s' % (p.name, ext)
               for p in self._model_info.parameters.user_parameters()
               for ext in ('npts', 'nsigmas', 'width')
               if p.polydisperse]
        return ret

    def clone(self):
        # type: () -> "SasviewModel"
        """
        Return a identical copy of self

        The built kernel model is not copied; the clone rebuilds it on
        first use.
        """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]


    def runXY(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info.parameters.has_2d:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.

        The kernel created for these q vectors is released afterwards, even
        if the evaluation raises.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        calculator = self._model.make_kernel(q_vectors)
        try:
            parameters = self._model_info.parameters
            pairs = [self._get_weights(p) for p in parameters.call_parameters]
            call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
            return calculator(call_details, values, cutoff=self.cutoff,
                              magnetic=is_magnetic)
        finally:
            calculator.release()

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius, or 1.0 if the model
            does not define ER
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio, or 1.0 if the model does not
            define VR
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        Only ``dispersion.type`` is used: a fresh disperser of that type is
        created and its parameters are stored in *self.dispersion*.

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion

        :raises ValueError: if *parameter* is not a model parameter.
        """
        if parameter in self.params:
            # TODO: Store the disperser object directly in the model.
            # The current method of relying on the sasview GUI to
            # remember them is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet # been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.  (2) Could write a disperser parameter retriever
            # in sasview.  (3) Could modify sasview to use sasmodels.weights
            # dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            dispersion = weights.MODELS[dispersion.type]()
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

    def _dispersion_mesh(self):
        # type: () -> Tuple[List[np.ndarray], np.ndarray]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
        """
        Return dispersion weights for parameter

        Parameters missing from *params* (hidden by the multiplicity) get the
        multiplicity value if they are the control parameter, NaN otherwise.
        Polydisperse parameters get their weights normalized to sum to 1.
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return [self.multiplicity], [1.0]
            else:
                return [np.nan], [1.0]
        elif par.polydisperse:
            dis = self.dispersion[par.name]
            value, weight = weights.get_weights(
                dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                self.params[par.name], par.limits, par.relative_pd)
            return value, weight / np.sum(weight)
        else:
            return [self.params[par.name]], [1.0]
[ce27e21]666
def test_model():
    # type: () -> float
    """
    Test that a sasview model (cylinder) can be run.

    Evaluates the model with ``evalDistribution([0.1, 0.1])`` (qx = qy = 0.1)
    and returns the result.
    """
    cylinder = _make_standard_model('cylinder')()
    return cylinder.evalDistribution([0.1, 0.1])
[de97440]675
def test_rpa():
    # type: () -> float
    """
    Test that a multiplicity model (rpa, with multiplicity 3) can be run.

    Evaluates the model with ``evalDistribution([0.1, 0.1])`` and returns
    the result.
    """
    RPA = _make_standard_model('rpa')
    rpa = RPA(3)
    return rpa.evalDistribution([0.1, 0.1])
[04045f4]684
[4d76711]685
def test_model_list():
    # type: () -> None
    """
    Make sure that every model returned by core.list_models builds as a
    sasview model.

    The first failure is annotated (via annotate_exception) with the name
    of the model being loaded and then re-raised.
    """
    from .exception import annotate_exception
    for model_name in core.list_models():
        try:
            _make_standard_model(model_name)
        except BaseException:
            annotate_exception("when loading " + model_name)
            raise
698
def test_old_name():
    # type: () -> None
    """
    Load and run cylinder model from sas.models.CylinderModel

    Does nothing when old-style plugin support is disabled or when sasview
    (the ``sas`` package) cannot be imported.
    """
    def sasview_importable():
        try:
            import sas
        except ImportError:
            return False
        return True

    if not (SUPPORT_OLD_STYLE_PLUGINS and sasview_importable()):
        return
    load_standard_models()
    from sas.models.CylinderModel import CylinderModel
    CylinderModel().evalDistribution([0.1, 0.1])
714
if __name__ == "__main__":
    # Smoke test: evaluate the cylinder model at qx = qy = 0.1.
    cylinder_value = test_model()
    print("cylinder(0.1,0.1)=%g" % cylinder_value)
Note: See TracBrowser for help on using the repository browser.