source: sasmodels/sasmodels/sasview_model.py @ 1a6cd57

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 1a6cd57 was 1a6cd57, checked in by jhbakker, 8 years ago

commit of stuff from master from fast-merge

  • Property mode set to 100644
File size: 26.0 KB
RevLine 
"""
Sasview model constructor.

Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
create a sasview model class to run that kernel as follows::

    from sasmodels.sasview_model import load_custom_model
    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')

All standard models can be built at once with :func:`load_standard_models`,
which records them by name so they can later be retrieved with
:func:`find_model`.
"""
[4d76711]10from __future__ import print_function
[87985ca]11
[ce27e21]12import math
13from copy import deepcopy
[2622b3f]14import collections
[4d76711]15import traceback
16import logging
[9457498]17from os.path import basename, splitext
[ce27e21]18
[7ae2b7f]19import numpy as np  # type: ignore
[ce27e21]20
[aa4946b]21from . import core
[4d76711]22from . import custom
[72a081d]23from . import generate
[fb5914f]24from . import weights
[6d6508e]25from . import modelinfo
[bde38b5]26from .details import make_kernel_args, dispersion_mesh
[ff7119b]27
[fa5fd8d]28try:
[60f03de]29    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable
[fa5fd8d]30    from .modelinfo import ModelInfo, Parameter
31    from .kernel import KernelModel
32    MultiplicityInfoType = NamedTuple(
33        'MuliplicityInfo',
34        [("number", int), ("control", str), ("choices", List[str]),
35         ("x_axis_label", str)])
[60f03de]36    SasviewModelType = Callable[[int], "SasviewModel"]
[fa5fd8d]37except ImportError:
38    pass
39
# When True, load_standard_models() also calls _register_old_models() so the
# new models are importable under their old sasview names (sas.models.<Old>),
# which old-style plugin modules rely on.
SUPPORT_OLD_STYLE_PLUGINS = True
41
def _register_old_models():
    # type: () -> None
    """
    Place the new models into sasview under the old names.

    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
    is available to the plugin modules.

    For every entry in the conversion table a pseudo-module registered in
    ``sys.modules`` as ``sas.models.<OldName>`` is created, holding an
    attribute ``<OldName>`` bound to the new model class, so that
    ``from sas.models.CylinderModel import CylinderModel`` keeps working.

    Entries whose new model is not available (for example because it failed
    to build in load_standard_models) are skipped with a warning instead of
    aborting the registration of all remaining models.
    """
    import sys
    import sas
    import sas.sascalc.fit
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit

    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE
    for new_name, conversion in CONVERSION_TABLE.items():
        old_name = conversion[0]
        try:
            model = find_model(new_name)
        except ValueError:
            # load_standard_models() logs and drops models that fail to
            # build, so they are missing from MODELS; one missing model
            # should not prevent the others from being registered.
            logging.warning("Cannot register old model name %s: model %r "
                            "is not loaded", old_name, new_name)
            continue
        module_attrs = {old_name: model}
        ConstructedModule = type(old_name, (), module_attrs)
        old_path = 'sas.models.' + old_name
        setattr(sas.models, old_path, ConstructedModule)
        sys.modules[old_path] = ConstructedModule
65
[9457498]66
# TODO: separate x_axis_label from multiplicity info
#: Multiplicity description for a model: the *number* of variants, the name
#: of the *control* parameter, its list of *choices*, and the *x_axis_label*
#: used for the model profile plot.
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo',
    "number control choices x_axis_label".split(),
)
72
#: Registry of loaded sasview model classes, keyed by model name.
MODELS = {}


def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Return the sasview model class called *modelname*.

    A name ending in ``.py`` is treated as the path of a custom model and is
    handed to load_custom_model.  Any other name must already be present in
    MODELS; otherwise ValueError is raised.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
88
[56b2687]89
# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Build every model reported by core.list_models() and return the classes.

    Each class that builds successfully is also recorded in MODELS under its
    name.  A model that raises while building has its traceback logged at
    error level and is omitted from the result.  When
    SUPPORT_OLD_STYLE_PLUGINS is set, the models are afterwards registered
    under their old sasview names as well.
    """
    loaded = []
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
        except Exception:
            logging.error(traceback.format_exc())
            continue
        MODELS[model_name] = model_class
        loaded.append(model_class)

    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()

    return loaded
[de97440]110
[4d76711]111
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load a custom model given the model path.

    If the loaded module defines a ``Model`` class (old-style plugin), that
    class is used after filling in ``name``, ``filename`` and ``id`` when they
    are missing.  Otherwise the module is converted with
    modelinfo.make_model_info and wrapped as a sasview model.

    The result is stored in MODELS.  If its name is already taken by a model
    loaded from a different file, the model is renamed to its id, or to
    ``<id>_user`` if the id is also taken.
    """
    kernel_module = custom.load_custom_kernel_module(path)
    try:
        model = kernel_module.Model
        # Old style models do not set the name in the class attributes, so
        # set it here; this name will be overridden when the object is created
        # with an instance variable that has the same value.
        if model.name == "":
            model.name = splitext(basename(path))[0]
        if not hasattr(model, 'filename'):
            model.filename = kernel_module.__file__
            # For old models, treat .pyc and .py files interchangeably.
            # This is needed because of the Sum|Multi(p1,p2) types of models
            # and the convoluted way in which they are created.
            if model.filename.endswith(".py"):
                logging.info("Loading %s as .pyc", model.filename)
                model.filename = model.filename+'c'
        if not hasattr(model, 'id'):
            model.id = splitext(basename(model.filename))[0]
    except AttributeError:
        model_info = modelinfo.make_model_info(kernel_module)
        model = _make_model_from_info(model_info)

    def _name_taken(candidate):
        # True when *candidate* is registered to a model from another file.
        return (candidate in MODELS
                and MODELS[candidate].filename != model.filename)

    if _name_taken(model.name):
        previous_name = model.name
        model.name = model.id
        # If the new model name is still in the model list (for instance,
        # if we put a cylinder.py in our plug-in directory), then append
        # an identifier.
        if _name_taken(model.name):
            model.name = model.id + '_user'
        logging.info("Model %s already exists: using %s [%s]",
                     previous_name, model.name, model.filename)

    MODELS[model.name] = model
    return model
[4d76711]154
[87985ca]155
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build the sasview model class for the model called *name*.

    *name* can be a standard model name or a path to a custom model.

    The kernel module found by generate.load_kernel_module is turned into a
    ModelInfo and wrapped by _make_model_from_info, giving a class that can be
    used directly as a sasview model.
    """
    kernel_module = generate.load_kernel_module(name)
    return _make_model_from_info(modelinfo.make_model_info(kernel_module))
[72a081d]168
169
def _make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Convert *model_info* into a SasView model wrapper.

    Returns a new SasviewModel subclass named ``model_info.name`` whose class
    attributes come from _generate_model_attributes, plus the model's
    ``filename`` and an ``__init__`` taking the optional multiplicity.
    """
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs = dict(_generate_model_attributes(model_info),
                       __init__=__init__,
                       filename=model_info.filename)
    return type(model_info.name, (SasviewModel,), class_attrs) # type: SasviewModelType
182
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Build the class attribute dictionary for the model described by
    *model_info*.

    The attributes let callers query the model details without creating an
    instance: identity and description, structure/form factor flags,
    multiplicity information, the drop-down choice list, and the names of the
    orientation, magnetic, fixed and non-fittable parameters.  Sequences are
    stored as tuples so the class attributes stay immutable.
    """
    # TODO: allow model to override axis labels input/output name/unit

    kernel_pars = model_info.parameters.kernel_parameters

    # Process multiplicity: the control parameter, if any, is not fittable.
    non_fittable = []  # type: List[str]
    if model_info.profile is not None:
        xlabel = model_info.profile_axes[0]
    else:
        xlabel = ""
    variants = MultiplicityInfo(0, "", [], xlabel)
    control = next((p for p in kernel_pars if p.name == model_info.control),
                   None)
    if control is not None:
        non_fittable.append(control.name)
        if control.choices:
            count = len(control.choices)
        else:
            count = int(control.limits[1])
        variants = MultiplicityInfo(count, control.name, control.choices,
                                    xlabel)

    # Only a single drop-down list parameter available
    fun_list = []
    choice_par = next((p for p in kernel_pars if p.choices), None)
    if choice_par is not None:
        fun_list = choice_par.choices
        if choice_par.length > 1:
            non_fittable.extend(choice_par.id + str(k)
                                for k in range(1, choice_par.length + 1))

    # Organize parameter sets
    orientation_params = []
    magnetic_params = []
    fixed = []
    for par in model_info.parameters.user_parameters():
        width_name = par.name + ".width"
        if par.type == 'orientation':
            orientation_params.extend([par.name, width_name])
            fixed.append(width_name)
        elif par.type == 'magnetic':
            orientation_params.append(par.name)
            magnetic_params.append(par.name)
            fixed.append(width_name)

    # Build class dictionary
    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants[0] > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }  # type: Dict[str, Any]
[4d76711]251
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete models are subclasses created by _make_model_from_info, which
    fills in the class attributes declared below from the model's ModelInfo.
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: Sequence[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: Sequence[str]
    #: names reported as fittable by is_fittable; generated from the ".width"
    #: names of the orientation and magnetic parameters
    fixed = None              # type: Sequence[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Create the per-instance parameter, dispersion and detail tables.

        *multiplicity* selects the model variant.  Parameters hidden for that
        variant, and the multiplicity control parameter itself, are left out
        of *params*.  Structure factor models also hide scale and background.
        """

        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the mulitplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            hidden = self._model_info.get_hidden_parameters(multiplicity)
            hidden |= set([self.multiplicity_info.control])
        else:
            hidden = set()
        if self._model_info.structure_factor:
            hidden.add('scale')
            hidden.add('background')
            # Note: this updates the model info shared by every instance;
            # _get_weights reads these defaults for hidden parameters.
            self._model_info.parameters.defaults['background'] = 0.

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters():
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __getstate__(self):
        # type: () -> Dict[str, Any]
        """
        Return the instance state used by pickle and copy.deepcopy.

        The kernel model is left out; calculate_Iq rebuilds it on demand.
        """
        state = self.__dict__.copy()
        # _model is only in the instance dict once calculate_Iq has run.
        state.pop('_model', None)
        # _model_info is kept, though it may need to be reloaded on set state
        # since it has pointers to python implementations of Iq, etc.
        return state

    def __setstate__(self, state):
        # type: (Dict[str, Any]) -> None
        """
        Restore the instance state; the kernel model is rebuilt lazily.
        """
        self.__dict__ = state
        self._model = None

    # These hooks were originally misspelled (and therefore never used by
    # pickle or deepcopy); keep the old names for any direct callers.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check
        """
        # TODO: for the future, ask the parameter itself, e.g.
        # self.params[str(par_name)].is_fittable()
        return par_name in self.fixed

    def getProfile(self):
        # type: () -> (np.ndarray, np.ndarray)
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        args = {} # type: Dict[str, Any]
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == self.multiplicity_info.control:
                value = float(self.multiplicity)
            elif p.length == 1:
                value = self.params.get(p.id, np.nan)
            else:
                value = np.array([self.params.get(p.id+str(k), np.nan)
                                  for k in range(1, p.length+1)])
            args[p.id] = value

        x, y = self._model_info.profile(**args)
        # NOTE(review): presumably converts the profile from 1e-6/A^2 SLD
        # units to plain SLD values — confirm.
        return x, 1e-6*y

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        :param name: name of the parameter, or "<parameter>.<field>" to set a
            dispersion field such as "radius.width"
        :param value: value of the parameter
        :raises ValueError: if the model does not contain the parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            par_name, field = toks
            if par_name in self.dispersion and field in self.dispersion[par_name]:
                self.dispersion[par_name][field] = value
                return
        elif name in self.params:
            # Look for standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        :param name: name of the parameter, or "<parameter>.<field>" to get a
            dispersion field such as "radius.width"
        :raises ValueError: if the model does not contain the parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            par_name, field = toks
            if par_name in self.dispersion and field in self.dispersion[par_name]:
                return self.dispersion[par_name][field]
        elif name in self.params:
            # Look for standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model
        """
        # TODO: fix test so that parameter order doesn't matter
        return ['%s.%s' % (p_name, ext)
                for p_name in self.dispersion.keys()
                for ext in ('npts', 'nsigmas', 'width')]

    def clone(self):
        # type: () -> "SasviewModel"
        """
        Return an identical copy of self.

        The copy shares no mutable state with the original; its kernel model
        is rebuilt on first use (see __getstate__).
        """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]


    def runXY(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info.parameters.has_2d:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        calculator = self._model.make_kernel(q_vectors)
        parameters = self._model_info.parameters
        pairs = [self._get_weights(p) for p in parameters.call_parameters]
        call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
        result = calculator(call_details, values, cutoff=self.cutoff,
                            magnetic=is_magnetic)
        calculator.release()
        # NOTE(review): _model is released but kept, so the next call reuses
        # it without rebuilding — confirm KernelModel supports reuse after
        # release().
        self._model.release()
        return result

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        The effective radius is averaged over the volume-parameter dispersion
        mesh; models without an ER function return 1.0.

        :return: the value of the effective radius
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        Models without a VR function return 1.0.

        :return: the value of the volf ratio
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> Dict[str, Any]
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion
        :raises ValueError: if *parameter* is not a model parameter
        """
        if parameter in self.params:
            # TODO: Store the disperser object directly in the model.
            # The current method of relying on the sasview GUI to
            # remember them is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.
            # (2) Could write a disperser parameter retriever in sasview.
            # (3) Could modify sasview to use sasmodels.weights dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

    def _dispersion_mesh(self):
        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
        """
        Return dispersion weights for parameter

        Parameters not in *params* use the multiplicity (for the control
        parameter) or their model default (NaN if none); polydisperse
        parameters get normalized dispersion weights; all others a single
        point with weight 1.
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return [self.multiplicity], [1.0]
            else:
                # For hidden parameters use the default value.
                value = self._model_info.parameters.defaults.get(par.name, np.nan)
                return [value], [1.0]
        elif par.polydisperse:
            dis = self.dispersion[par.name]
            if dis['type'] == 'array':
                value, weight = dis['values'], dis['weights']
            else:
                value, weight = weights.get_weights(
                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                    self.params[par.name], par.limits, par.relative_pd)
            return value, weight / np.sum(weight)
        else:
            return [self.params[par.name]], [1.0]
[ce27e21]700
def test_model():
    # type: () -> float
    """
    Check that the standard cylinder model builds and evaluates at one point.
    """
    cylinder_model = _make_standard_model('cylinder')()
    return cylinder_model.evalDistribution([0.1, 0.1])
[de97440]709
def test_structure_factor():
    # type: () -> None
    """
    Test that a sasview structure factor model (hardsphere) can be run.

    Raises ValueError if the evaluated intensity is NaN.
    """
    Model = _make_standard_model('hardsphere')
    model = Model()
    value = model.evalDistribution([0.1, 0.1])
    if np.isnan(value):
        raise ValueError("hardsphere returns null")
720
def test_rpa():
    # type: () -> float
    """
    Test that a multiplicity model (rpa with multiplicity 3) can be run.
    """
    RPA = _make_standard_model('rpa')
    rpa = RPA(3)
    return rpa.evalDistribution([0.1, 0.1])
[04045f4]729
[4d76711]730
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models.

    Any failure is re-raised after being annotated with the model name.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            annotate_exception("when loading "+name)
            raise
743
def test_old_name():
    # type: () -> None
    """
    Load the cylinder model through its old name, sas.models.CylinderModel,
    and evaluate it.  Does nothing if old-style plugin support is off or
    sasview is not importable.
    """
    if SUPPORT_OLD_STYLE_PLUGINS:
        try:
            import sas  # only importable when sasview is on the path
        except ImportError:
            sas = None
        if sas is not None:
            load_standard_models()
            from sas.models.CylinderModel import CylinderModel
            CylinderModel().evalDistribution([0.1, 0.1])
759
if __name__ == "__main__":
    cylinder_iq = test_model()
    print("cylinder(0.1,0.1)=%g" % cylinder_iq)
Note: See TracBrowser for help on using the repository browser.