source: sasmodels/sasmodels/sasview_model.py @ eb3fab6

core_shell_microgels, costrafo411, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since eb3fab6 was eb3fab6, checked in by Gonzalez, Miguel <gonzalez@…>, 7 years ago

Fix interface problem appearing when a form factor that has multiplicity is combined with a structure factor

  • Property mode set to 100644
File size: 30.1 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17from os.path import basename, splitext, abspath, getmtime
18try:
19    import _thread as thread
20except ImportError:
21    import thread
22
23import numpy as np  # type: ignore
24
25from . import core
26from . import custom
27from . import product
28from . import generate
29from . import weights
30from . import modelinfo
31from .details import make_kernel_args, dispersion_mesh
32
33try:
34    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable
35    from .modelinfo import ModelInfo, Parameter
36    from .kernel import KernelModel
37    MultiplicityInfoType = NamedTuple(
38        'MuliplicityInfo',
39        [("number", int), ("control", str), ("choices", List[str]),
40         ("x_axis_label", str)])
41    SasviewModelType = Callable[[int], "SasviewModel"]
42except ImportError:
43    pass
44
logger = logging.getLogger(__name__)

# Module-wide lock; calculate_Iq and calc_composition_models hold it while
# running the kernel.
calculation_lock = thread.allocate_lock()

#: When True, pre-existing plugins using the old model names and parameters
#: remain supported (see _register_old_models).
SUPPORT_OLD_STYLE_PLUGINS = True

# TODO: separate x_axis_label from multiplicity info
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo', 'number control choices x_axis_label')

#: set of defined models (standard and custom)
MODELS = dict()  # type: Dict[str, SasviewModelType]
#: custom model {path: model} mapping so we can check timestamps
MODEL_BY_PATH = dict()  # type: Dict[str, SasviewModelType]
63
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Look up a model class by name.

    A name ending in ``.py`` is treated as the path of a custom model file
    and loaded with load_custom_model.  Any other name must already be
    registered in MODELS.

    :raises ValueError: if *modelname* is not a registered model.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
78
79
80# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Build every predefined model, register it in MODELS, and return the list
    of all models currently registered in MODELS.

    A model that fails to build is skipped; its traceback is logged at error
    level.  When SUPPORT_OLD_STYLE_PLUGINS is set, the models are also made
    available under their old sasview names.
    """
    for model_name in core.list_models():
        try:
            built = _make_standard_model(model_name)
        except Exception:
            logger.error(traceback.format_exc())
            continue
        MODELS[model_name] = built

    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()

    return list(MODELS.values())
98
99
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load a custom model given the model path.

    Loaded models are cached in MODEL_BY_PATH.  A cached model is returned
    as-is unless the file's modification time differs from the one recorded
    when it was loaded.

    If the module defines a ``Model`` class (old-style plugin), that class is
    used directly, with *name*, *filename* and *id* filled in when missing.
    Otherwise a SasView model class is generated from the module's model
    info.

    The model is registered in MODELS.  If its name clashes with a model
    loaded from a different file, it is renamed to its id, and then to
    ``<id>_user`` if that name clashes as well.
    """
    model = MODEL_BY_PATH.get(path, None)
    if model is not None and model.timestamp == getmtime(path):
        #logger.info("Model already loaded %s", path)
        return model

    #logger.info("Loading model %s", path)
    kernel_module = custom.load_custom_kernel_module(path)
    if hasattr(kernel_module, 'Model'):
        model = kernel_module.Model
        # Old style models do not set the name in the class attributes, so
        # set it here; this name will be overridden when the object is created
        # with an instance variable that has the same value.
        if model.name == "":
            model.name = splitext(basename(path))[0]
        if not hasattr(model, 'filename'):
            # Point at the source file rather than the compiled one.  Only the
            # trailing extension is rewritten so that a directory whose name
            # contains ".pyc" is left untouched.
            filename = abspath(kernel_module.__file__)
            if filename.endswith('.pyc'):
                filename = filename[:-1]
            model.filename = filename
        if not hasattr(model, 'id'):
            model.id = splitext(basename(model.filename))[0]
    else:
        model_info = modelinfo.make_model_info(kernel_module)
        model = make_model_from_info(model_info)
    model.timestamp = getmtime(path)

    # If a model name already exists and we are loading a different model,
    # use the model file name as the model name.
    if model.name in MODELS and not model.filename == MODELS[model.name].filename:
        _previous_name = model.name
        model.name = model.id

        # If the new model name is still in the model list (for instance,
        # if we put a cylinder.py in our plug-in directory), then append
        # an identifier.
        if model.name in MODELS and not model.filename == MODELS[model.name].filename:
            model.name = model.id + '_user'
        logger.info("Model %s already exists: using %s [%s]",
                    _previous_name, model.name, model.filename)

    MODELS[model.name] = model
    MODEL_BY_PATH[path] = model
    return model
145
146
def make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Create a new SasviewModel subclass that wraps *model_info*.

    The class is named after ``model_info.name``.  Its class attributes come
    from _generate_model_attributes, plus ``filename`` taken from the model
    info and a constructor forwarding *multiplicity* to SasviewModel.
    """
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_dict = _generate_model_attributes(model_info)
    class_dict.update(__init__=__init__, filename=model_info.filename)
    model_class = type(model_info.name, (SasviewModel,), class_dict)  # type: SasviewModelType
    return model_class
159
160
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build the sasview model class for the model called *name*.

    *name* is passed unchanged to generate.load_kernel_module.  The loaded
    module is converted to model info and wrapped by make_model_from_info.

    Returns a class that can be used directly as a sasview model.
    """
    return make_model_from_info(
        modelinfo.make_model_info(generate.load_kernel_module(name)))
173
174
def _register_old_models():
    # type: () -> None
    """
    Place the new models into sasview under the old names.

    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
    is available to the plugin modules.

    For each entry of the 3.1.2 conversion table, a stand-in module named
    after the old model is built that exposes the new model class under the
    old name.  It is registered both in sys.modules as
    ``sas.models.<OldName>`` and as the attribute ``<OldName>`` of
    sas.models.  As a result, ``from sas.models.OldName import OldName`` and
    ``sas.models.OldName.OldName`` both work.

    Entries whose new model is not registered (for instance because it
    failed to build in load_standard_models) are skipped with a warning.
    """
    import sys
    import sas   # needed in order to set sas.models
    import sas.sascalc.fit
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit

    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE
    for new_name, conversion in CONVERSION_TABLE.get((3, 1, 2), {}).items():
        # CoreShellEllipsoidModel => core_shell_ellipsoid:1
        new_name = new_name.split(':')[0]
        old_name = conversion[0] if len(conversion) < 3 else conversion[2]
        try:
            new_model = find_model(new_name)
        except ValueError:
            # Keep the best-effort behaviour of load_standard_models: a model
            # that failed to load should not abort registration of the rest.
            logger.warning("Cannot register old name %s: model %s is not loaded",
                           old_name, new_name)
            continue
        module_attrs = {old_name: new_model}
        ConstructedModule = type(old_name, (), module_attrs)
        old_path = 'sas.models.' + old_name
        # Bind the stand-in under its short name, as the import system would
        # for a real submodule.  A dotted attribute name is unreachable via
        # normal attribute access.
        setattr(sas.models, old_name, ConstructedModule)
        sys.modules[old_path] = ConstructedModule
200
201
def MultiplicationModel(form_factor, structure_factor):
    # type: ("SasviewModel", "SasviewModel") -> "SasviewModel"
    """
    Return an instance of the product model built from the *form_factor*
    and *structure_factor* model instances.

    If the form factor is a multiplicity model, the product model class
    inherits the form factor's multiplicity info and non-fittable
    parameters.  The product is then instantiated with the form factor's
    current multiplicity.
    """
    model_info = product.make_product_info(form_factor._model_info,
                                           structure_factor._model_info)
    ConstructedModel = make_model_from_info(model_info)
    if not form_factor.is_multiplicity_model:
        return ConstructedModel()

    # ConstructedModel is a freshly generated class, so overriding its class
    # attributes here does not affect any other model.
    ConstructedModel.is_multiplicity_model = True
    ConstructedModel.multiplicity_info = form_factor.multiplicity_info
    ConstructedModel.non_fittable = form_factor.non_fittable
    return ConstructedModel(form_factor.multiplicity)
215
216
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Generate the class attributes for the model.

    The returned dictionary lets the model details be queried from the class
    itself, without instantiating it.  It covers:
    - names, description and category
    - structure/form factor flags
    - multiplicity information
    - orientation and magnetic parameter names
    - the "fixed" width names
    - the non-fittable parameters
    - the drop-down choices

    All the attributes should be immutable to avoid accidents.
    """
    # TODO: allow model to override axis labels input/output name/unit

    xlabel = "" if model_info.profile is None else model_info.profile_axes[0]

    # Multiplicity is taken from the first kernel parameter whose name is
    # model_info.control; that control parameter is not fittable.
    non_fittable = []  # type: List[str]
    control = next((par for par in model_info.parameters.kernel_parameters
                    if par.name == model_info.control), None)
    if control is None:
        variants = MultiplicityInfo(0, "", [], xlabel)
    else:
        non_fittable.append(control.name)
        count = (len(control.choices) if control.choices
                 else int(control.limits[1]))
        variants = MultiplicityInfo(count, control.name, control.choices, xlabel)

    # Only a single drop-down list parameter is available: the first kernel
    # parameter that has choices supplies them.
    fun_list = []
    dropdown = next((par for par in model_info.parameters.kernel_parameters
                     if par.choices), None)
    if dropdown is not None:
        fun_list = dropdown.choices
        if dropdown.length > 1:
            non_fittable.extend(
                dropdown.id + str(k) for k in range(1, dropdown.length + 1))

    # Orientation and magnetic parameters both go into orientation_params
    # (orientation ones together with their ".width" names).  For both
    # kinds the ".width" name is recorded in fixed.
    orientation_params = []
    magnetic_params = []
    fixed = []
    for par in model_info.parameters.user_parameters({}, is2d=True):
        width_name = par.name + ".width"
        if par.type == 'orientation':
            orientation_params.extend([par.name, width_name])
        elif par.type == 'magnetic':
            orientation_params.append(par.name)
            magnetic_params.append(par.name)
        else:
            continue
        fixed.append(width_name)

    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }
285
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete subclasses are generated by make_model_from_info, which fills in
    the class attributes below from the model info.  Each instance holds:
    - the current parameter values (*params*)
    - the dispersity settings (*dispersion*)
    - the units and limits (*details*)

    The kernel model is built lazily the first time the model is evaluated.
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: List[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: List[str]
    #: names of the fittable parameters
    fixed = None              # type: List[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]
    #: the kernel's ``results`` attribute from the most recent evaluation,
    #: or None if the kernel did not provide one; set by _calculate_Iq.
    _intermediate_results = None

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Initialise parameter values, details and dispersity settings from the
        model info defaults.

        :param multiplicity: value for the model's multiplicity control
            parameter, or None.  When given, the parameters hidden at that
            multiplicity, and the control parameter itself, are left out of
            *params*.
        """

        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            hidden = self._model_info.get_hidden_parameters(multiplicity)
            hidden |= set([self.multiplicity_info.control])
        else:
            hidden = set()
        if self._model_info.structure_factor:
            hidden.add('scale')
            hidden.add('background')
            # Hidden parameters are evaluated at their defaults (see
            # _get_weights), so force the structure factor background to 0.
            # Note: this modifies the model info shared by all instances.
            self._model_info.parameters.defaults['background'] = 0.

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters({}, is2d=True):
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __getstate__(self):
        # type: () -> Dict[str, Any]
        """
        Return the instance state for pickling and copying.

        The compiled kernel model is excluded; it is rebuilt on demand by
        _calculate_Iq.
        """
        state = self.__dict__.copy()
        # _model is only in the instance dict once a calculation has run.
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __setstate__(self, state):
        # type: (Dict[str, Any]) -> None
        """
        Restore the state returned by __getstate__.

        The kernel model is reset so that it is rebuilt lazily.
        """
        self.__dict__ = state
        self._model = None

    # Backward-compatible aliases for the previous (misspelled) names, which
    # pickle and copy never called.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check
        """
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()


    def getProfile(self):
        # type: () -> (np.ndarray, np.ndarray)
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        args = {} # type: Dict[str, Any]
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == self.multiplicity_info.control:
                value = float(self.multiplicity)
            elif p.length == 1:
                value = self.params.get(p.id, np.nan)
            else:
                value = np.array([self.params.get(p.id+str(k), np.nan)
                                  for k in range(1, p.length+1)])
            args[p.id] = value

        x, y = self._model_info.profile(**args)
        # Scale the profile SLD values by 1e-6 before returning them.
        return x, 1e-6*y

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        :param name: name of the parameter, or "<parameter>.<field>" (e.g.
            "radius.width") for a dispersity setting
        :param value: value of the parameter

        :raises ValueError: if the model does not contain the parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            par_name, field = toks
            settings = self.dispersion.get(par_name)
            if settings is not None and field in settings:
                settings[field] = value
                return
        elif name in self.params:
            # Look for standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        :param name: name of the parameter, or "<parameter>.<field>" (e.g.
            "radius.width") for a dispersity setting

        :raises ValueError: if the model does not contain the parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            par_name, field = toks
            settings = self.dispersion.get(par_name)
            if settings is not None and field in settings:
                return settings[field]
        elif name in self.params:
            # Look for standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model
        """
        # TODO: fix test so that parameter order doesn't matter
        ret = ['%s.%s' % (p_name, ext)
               for p_name in self.dispersion.keys()
               for ext in ('npts', 'nsigmas', 'width')]
        #print(ret)
        return ret

    def clone(self):
        # type: () -> "SasviewModel"
        """ Return an identical copy of self (the kernel model is rebuilt lazily) """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]


    def runXY(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info.parameters.has_2d:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calc_composition_models(self, qx):
        """
        returns parts of the composition model or None if not a composition
        model.
        """
        # TODO: have calculate_Iq return the intermediates.
        #
        # The current interface causes calculate_Iq() to be called twice,
        # once to get the combined result and again to get the intermediate
        # results.  This is necessary for now.
        # Long term, the solution is to change the interface to calculate_Iq
        # so that it returns a results object containing all the bits:
        #     the A, B, C, ... of the composition model (and any subcomponents?)
        #     the P and S of the product model,
        #     the combined model before resolution smearing,
        #     the sasmodel before sesans conversion,
        #     the oriented 2D model used to fit oriented usans data,
        #     the final I(q),
        #     ...
        #
        # Have the model calculator add all of these blindly to the data
        # tree, and update the graphs which contain them.  The fitter
        # needs to be updated to use the I(q) value only, ignoring the rest.
        #
        # The simple fix of returning the existing intermediate results
        # will not work for a couple of reasons: (1) another thread may
        # sneak in to compute its own results before calc_composition_models
        # is called, and (2) calculate_Iq is currently called three times:
        # once with q, once with q values before qmin and once with q values
        # after q max.  Both of these should be addressed before
        # replacing this code.
        composition = self._model_info.composition
        if composition and composition[0] == 'product': # only P*S for now
            with calculation_lock:
                self._calculate_Iq(qx)
                return self._intermediate_results
        else:
            return None

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        ## uncomment the following when trying to debug the uncoordinated calls
        ## to calculate_Iq
        #if calculation_lock.locked():
        #    logger.info("calculation waiting for another thread to complete")
        #    logger.info("\n".join(traceback.format_stack()))

        with calculation_lock:
            return self._calculate_Iq(qx, qy)

    def _calculate_Iq(self, qx, qy=None):
        """
        Evaluate the kernel for *qx* (and *qy* for 2D) with the current
        parameters.

        The callers in this class hold calculation_lock while calling this.
        The kernel's intermediate results are stored in
        _intermediate_results.  The kernel and the model are released even if
        the evaluation raises.
        """
        #core.HAVE_OPENCL = False
        if self._model is None:
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        calculator = self._model.make_kernel(q_vectors)
        try:
            parameters = self._model_info.parameters
            pairs = [self._get_weights(p) for p in parameters.call_parameters]
            #weights.plot_weights(self._model_info, pairs)
            call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
            #call_details.show()
            #print("pairs", pairs)
            #print("params", self.params)
            #print("values", values)
            #print("is_mag", is_magnetic)
            result = calculator(call_details, values, cutoff=self.cutoff,
                                magnetic=is_magnetic)
            self._intermediate_results = getattr(calculator, 'results', None)
        finally:
            calculator.release()
            self._model.release()
        return result

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius (1.0 if the model has no
            ER function)
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            #print(values[0].shape, weights.shape, fv.shape)
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio (1.0 if the model has no VR
            function)
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion

        :raises ValueError: if *parameter* is not a model parameter
        """
        if parameter in self.params:
            # TODO: Store the disperser object directly in the model.
            # The current method of relying on the sasview GUI to
            # remember them is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.
            # (2) Could write a disperser parameter retriever in sasview.
            # (3) Could modify sasview to use sasmodels.weights dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

    def _dispersion_mesh(self):
        # type: () -> Tuple[List[np.ndarray], np.ndarray]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
        """
        Return dispersion weights for parameter

        Parameters missing from *params* are hidden.  The multiplicity control
        parameter takes the current multiplicity.  Any other hidden parameter
        takes its default value (NaN if it has none).  Both get weight 1.

        Polydisperse parameters return their dispersity points, with the
        weights normalized to sum to 1.
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return [self.multiplicity], [1.0]
            else:
                # For hidden parameters use the default value.
                value = self._model_info.parameters.defaults.get(par.name, np.nan)
                return [value], [1.0]
        elif par.polydisperse:
            dis = self.dispersion[par.name]
            if dis['type'] == 'array':
                value, weight = dis['values'], dis['weights']
            else:
                value, weight = weights.get_weights(
                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                    self.params[par.name], par.limits, par.relative_pd)
            return value, weight / np.sum(weight)
        else:
            return [self.params[par.name]], [1.0]
784
def test_cylinder():
    # type: () -> np.ndarray
    """
    Build the cylinder model and return its value at (qx, qy) = (0.1, 0.1).
    """
    cylinder_model = _make_standard_model('cylinder')()
    return cylinder_model.evalDistribution([0.1, 0.1])
793
def test_structure_factor():
    # type: () -> None
    """
    Check that the hardsphere structure factor evaluates at
    (qx, qy) = (0.1, 0.1) without producing NaN.
    """
    hardsphere = _make_standard_model('hardsphere')()
    result = hardsphere.evalDistribution([0.1, 0.1])
    if np.isnan(result):
        raise ValueError("hardsphere returns null")
804
def test_rpa():
    # type: () -> np.ndarray
    """
    Build the RPA model with multiplicity 3 and return its value at
    (qx, qy) = (0.1, 0.1).
    """
    rpa_model = _make_standard_model('rpa')(3)
    return rpa_model.evalDistribution([0.1, 0.1])
813
def test_empty_distribution():
    # type: () -> None
    """
    Check that the cylinder model returns NaN when the dispersity
    distribution has no points.
    """
    cylinder = _make_standard_model('cylinder')()
    # A negative radius is intended to leave no valid dispersity points.
    for par_name, par_value in (('radius', -1.0), ('background', 0.)):
        cylinder.setParam(par_name, par_value)
    Iq = cylinder.evalDistribution(np.asarray([0.1]))
    assert np.isnan(Iq[0]), "empty distribution fails"
825
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models.

    A failure is annotated with the name of the offending model before being
    re-raised.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            # Catch only real errors; KeyboardInterrupt/SystemExit pass through
            # unannotated.
            annotate_exception("when loading "+name)
            raise
838
def test_old_name():
    # type: () -> None
    """
    Load the cylinder model through its old name, sas.models.CylinderModel,
    and evaluate it at (qx, qy) = (0.1, 0.1).

    Does nothing when old-style plugin support is disabled or when sasview
    cannot be imported.
    """
    if SUPPORT_OLD_STYLE_PLUGINS:
        try:
            # if sasview is not on the path then don't try to test it
            import sas
        except ImportError:
            pass
        else:
            load_standard_models()
            from sas.models.CylinderModel import CylinderModel
            CylinderModel().evalDistribution([0.1, 0.1])
854
if __name__ == "__main__":
    # Script entry point: print the cylinder model's value at one 2-D point.
    print("cylinder(0.1,0.1)=%g" % (test_cylinder(),))
    #test_empty_distribution()
Note: See TracBrowser for help on using the repository browser.