source: sasmodels/sasmodels/sasview_model.py @ 749a7d4

core_shell_microgels, costrafo411, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 749a7d4 was 749a7d4, checked in by Paul Kienzle <pkienzle@…>, 3 years ago

test that model calculation returns NaN if point is out of range

  • Property mode set to 100644
File size: 28.0 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17from os.path import basename, splitext
18import thread
19
20import numpy as np  # type: ignore
21
22from . import core
23from . import custom
24from . import product
25from . import generate
26from . import weights
27from . import modelinfo
28from .details import make_kernel_args, dispersion_mesh
29
30try:
31    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable
32    from .modelinfo import ModelInfo, Parameter
33    from .kernel import KernelModel
34    MultiplicityInfoType = NamedTuple(
35        'MuliplicityInfo',
36        [("number", int), ("control", str), ("choices", List[str]),
37         ("x_axis_label", str)])
38    SasviewModelType = Callable[[int], "SasviewModel"]
39except ImportError:
40    pass
41
# Module-wide lock taken by SasviewModel.calculate_Iq around every kernel
# evaluation, so only one calculation runs at a time across threads.
# NOTE(review): `thread` is the Python 2 module name (Python 3 renamed it to
# `_thread`; threading.Lock() works in both) — consider switching.
calculation_lock = thread.allocate_lock()

# When True, load_standard_models() also calls _register_old_models() so the
# models are reachable under their old sasview names (sas.models.<OldName>).
SUPPORT_OLD_STYLE_PLUGINS = True
45
def _register_old_models():
    # type: () -> None
    """
    Place the new models into sasview under the old names.

    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
    is available to the plugin modules.  Then, for every entry under the
    (3, 1, 2) key of the conversion table, expose a pseudo-module
    ``sas.models.<OldName>`` whose ``<OldName>`` attribute is the new model
    class.  This makes both ``from sas.models.<OldName> import <OldName>``
    and attribute access ``sas.models.<OldName>`` work.
    """
    import sys
    import sas   # needed in order to set sas.models
    import sas.sascalc.fit
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit

    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE
    for new_name, conversion in CONVERSION_TABLE.get((3, 1, 2), {}).items():
        # CoreShellEllipsoidModel => core_shell_ellipsoid:1
        new_name = new_name.split(':')[0]
        old_name = conversion[0] if len(conversion) < 3 else conversion[2]
        module_attrs = {old_name: find_model(new_name)}
        ConstructedModule = type(old_name, (), module_attrs)
        old_path = 'sas.models.' + old_name
        # Bind under the bare old name.  Using the dotted path here would
        # create an attribute literally named "sas.models.<OldName>", which
        # normal attribute access (sas.models.<OldName>) can never reach.
        setattr(sas.models, old_name, ConstructedModule)
        # Registering in sys.modules makes "from sas.models.<OldName> import
        # <OldName>" resolve to the constructed pseudo-module.
        sys.modules[old_path] = ConstructedModule
71
72
# TODO: separate x_axis_label from multiplicity info
# Multiplicity description attached to every generated model class:
#   number       -- number of variants selectable through the control parameter
#   control      -- name of the control parameter ("" when there is none)
#   choices      -- the control parameter's choice list, if any
#   x_axis_label -- label for the x axis of the model's profile plot
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo', 'number control choices x_axis_label')
78
# Registry of model classes keyed by model name.
MODELS = {}  # type: Dict[str, SasviewModelType]


def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Return the model class registered under *modelname*.

    A name ending in ``.py`` is loaded as a custom model through
    load_custom_model(); any other name must already be present in MODELS.

    :raises ValueError: if *modelname* is not a registered model.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r" % modelname)
    return MODELS[modelname]
94
95
# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Build a sasview model class for every model in core.list_models().

    Each class that builds successfully is registered in MODELS under its
    name and returned, in list order.  A model that fails to build is
    skipped after its traceback is logged at error level.  When
    SUPPORT_OLD_STYLE_PLUGINS is set, the models are additionally
    registered under their old sasview names.
    """
    loaded = []
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
        except Exception:
            logging.error(traceback.format_exc())
            continue
        MODELS[model_name] = model_class
        loaded.append(model_class)

    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()

    return loaded
116
117
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load a custom model given the model path.

    If the loaded module defines an old-style ``Model`` class, that class is
    used, and any missing ``name``, ``filename`` and ``id`` attributes are
    filled in from the file path.  Otherwise the module is treated as a
    sasmodels kernel module and wrapped with _make_model_from_info().

    The result is registered in MODELS.  If a different model (a different
    filename) already uses the same name, the new model is renamed to its id,
    or to ``<id>_user`` if the id is also taken.
    """
    module = custom.load_custom_kernel_module(path)
    try:
        model = module.Model
        # Old-style plugin classes leave the name empty; default it to the
        # file stem.  Instances later shadow this with an identical value.
        if model.name == "":
            model.name = splitext(basename(path))[0]
        if not hasattr(model, 'filename'):
            # Old-style models treat .py and .pyc interchangeably; this is
            # needed for the convoluted way the Sum|Multi(p1,p2) models are
            # created, so record the compiled-file name.
            model.filename = module.__file__
            if model.filename.endswith(".py"):
                logging.info("Loading %s as .pyc", model.filename)
                model.filename += 'c'
        if not hasattr(model, 'id'):
            model.id = splitext(basename(model.filename))[0]
    except AttributeError:
        # No usable old-style Model class: build one from the kernel module.
        model = _make_model_from_info(modelinfo.make_model_info(module))

    # Avoid clobbering a different model that already owns this name.
    registered = MODELS.get(model.name)
    if registered is not None and registered.filename != model.filename:
        requested_name = model.name
        model.name = model.id
        # Still taken (e.g. a cylinder.py in the plug-in directory), so
        # append an identifier.
        registered = MODELS.get(model.name)
        if registered is not None and registered.filename != model.filename:
            model.name = model.id + '_user'
        logging.info("Model %s already exists: using %s [%s]",
                     requested_name, model.name, model.filename)

    MODELS[model.name] = model
    return model
160
161
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Return a sasview model class for the model called *name*.

    *name* is passed to generate.load_kernel_module(); whatever that
    accepts (a standard model name, or presumably a path to a custom
    model) can be used here.  The result can be used directly as a
    sasview model class.
    """
    info = modelinfo.make_model_info(generate.load_kernel_module(name))
    return _make_model_from_info(info)
174
175
def MultiplicationModel(form_factor, structure_factor):
    # type: ("SasviewModel", "SasviewModel") -> "SasviewModel"
    """
    Return an instance of the product model built from *form_factor* and
    *structure_factor*, both SasviewModel instances.
    """
    product_info = product.make_product_info(
        form_factor._model_info, structure_factor._model_info)
    product_class = _make_model_from_info(product_info)
    return product_class()
182
def _make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Convert *model_info* into a SasView model wrapper class.

    The new class subclasses SasviewModel and is named after the model.
    It carries the attributes from _generate_model_attributes() plus the
    model's source ``filename``, and defines an ``__init__`` that
    delegates to SasviewModel.__init__.
    """
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs = dict(_generate_model_attributes(model_info),
                       __init__=__init__,
                       filename=model_info.filename)
    return type(model_info.name, (SasviewModel,), class_attrs)
195
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Generate the class attributes for the model.

    The returned mapping holds everything needed to query the model (name,
    description, category, parameter groupings, multiplicity information)
    from the class itself, without instantiating it.  Sequence-valued
    attributes are returned as tuples so they cannot be mutated by accident.
    """
    # TODO: allow model to override axis labels input/output name/unit

    xlabel = "" if model_info.profile is None else model_info.profile_axes[0]

    # Multiplicity: the control parameter, if any, is not fittable and sets
    # the number of variants (its choice count, or its upper limit).
    non_fittable = []  # type: List[str]
    control = next((par for par in model_info.parameters.kernel_parameters
                    if par.name == model_info.control), None)
    if control is None:
        variants = MultiplicityInfo(0, "", [], xlabel)
    else:
        non_fittable.append(control.name)
        count = len(control.choices) if control.choices else int(control.limits[1])
        variants = MultiplicityInfo(count, control.name, control.choices, xlabel)

    # Only a single drop-down list parameter is supported: take the first
    # parameter that has choices.
    fun_list = []
    selector = next((par for par in model_info.parameters.kernel_parameters
                     if par.choices), None)
    if selector is not None:
        fun_list = selector.choices
        if selector.length > 1:
            non_fittable.extend(selector.id + str(k)
                                for k in range(1, selector.length + 1))

    # Group the 2-D user parameters.
    orientation_params = []
    magnetic_params = []
    fixed = []
    for par in model_info.parameters.user_parameters({}, is2d=True):
        if par.type == 'orientation':
            orientation_params.extend([par.name, par.name + ".width"])
            fixed.append(par.name + ".width")
        elif par.type == 'magnetic':
            orientation_params.append(par.name)
            magnetic_params.append(par.name)
            fixed.append(par.name + ".width")

    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }
264
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete subclasses are created by _make_model_from_info, which fills
    in the class attributes declared below from a ModelInfo.
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: Sequence[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: Sequence[str]
    #: "<par>.width" names of the orientation/magnetic parameters; this is
    #: the list that is_fittable() tests membership against
    fixed = None              # type: Sequence[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Create a model instance with default parameter values.

        :param multiplicity: multiplicity value for multiplicity models, or
            None.  When given, the parameters hidden at this multiplicity
            and the multiplicity control parameter itself are not exposed.
        """

        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            hidden = self._model_info.get_hidden_parameters(multiplicity)
            hidden |= set([self.multiplicity_info.control])
        else:
            hidden = set()
        if self._model_info.structure_factor:
            hidden.add('scale')
            hidden.add('background')
            self._model_info.parameters.defaults['background'] = 0.

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters({}, is2d=True):
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __getstate__(self):
        # type: () -> Dict[str, Any]
        """
        Return the state used by pickle and deepcopy.

        The compiled kernel model (``_model``) is dropped; _calculate_Iq
        rebuilds it on demand when ``_model`` is None.
        """
        state = self.__dict__.copy()
        # _model is only in the instance dict once a calculation has run.
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __setstate__(self, state):
        # type: (Dict[str, Any]) -> None
        """Restore state saved by __getstate__; the kernel is rebuilt lazily."""
        self.__dict__ = state
        self._model = None

    # Aliases for the original misspelled names (which pickle/copy never
    # called), kept in case anything invokes them directly.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check
        """
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()

    def getProfile(self):
        # type: () -> Tuple[np.ndarray, np.ndarray]
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        args = {} # type: Dict[str, Any]
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == self.multiplicity_info.control:
                value = float(self.multiplicity)
            elif p.length == 1:
                value = self.params.get(p.id, np.nan)
            else:
                value = np.array([self.params.get(p.id+str(k), np.nan)
                                  for k in range(1, p.length+1)])
            args[p.id] = value

        x, y = self._model_info.profile(**args)
        # Scale the profile's SLD values by 1e-6 (presumably converting from
        # 1e-6/Ang^2 parameter units -- TODO confirm).
        return x, 1e-6*y

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        :param name: name of the parameter; a name with exactly one dot,
            such as ``radius.width``, addresses a dispersion field
        :param value: value of the parameter
        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Dispersion parameter, e.g. "radius.width"
            par_name, field = toks
            dispersion = self.dispersion.get(par_name)
            if dispersion is not None and field in dispersion:
                dispersion[field] = value
                return
        elif name in self.params:
            # Standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        :param name: name of the parameter; a name with exactly one dot,
            such as ``radius.width``, addresses a dispersion field
        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Dispersion parameter, e.g. "radius.width"
            par_name, field = toks
            dispersion = self.dispersion.get(par_name)
            if dispersion is not None and field in dispersion:
                return dispersion[field]
        elif name in self.params:
            # Standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model
        """
        # TODO: fix test so that parameter order doesn't matter
        return ['%s.%s' % (p_name, ext)
                for p_name in self.dispersion.keys()
                for ext in ('npts', 'nsigmas', 'width')]

    def clone(self):
        # type: () -> "SasviewModel"
        """ Return a identical copy of self (without the compiled kernel) """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, Tuple[float, float], List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]


    def runXY(self, x=0.0):
        # type: (Union[float, Tuple[float, float], List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info.parameters.has_2d:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def get_composition_models(self):
        """
            Returns usable models that compose this model
        """
        s_model = None
        p_model = None
        if hasattr(self._model_info, "composition") \
           and self._model_info.composition is not None:
            p_model = _make_model_from_info(self._model_info.composition[1][0])()
            s_model = _make_model_from_info(self._model_info.composition[1][1])()
        return p_model, s_model

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        ## uncomment the following when trying to debug the uncoordinated calls
        ## to calculate_Iq
        #if calculation_lock.locked():
        #    logging.info("calculation waiting for another thread to complete")
        #    logging.info("\n".join(traceback.format_stack()))

        with calculation_lock:
            return self._calculate_Iq(qx, qy)

    def _calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Unlocked implementation of calculate_Iq.

        Builds the kernel model on first use, evaluates it at the given q
        vectors with the current parameters and dispersion weights, and
        releases the kernel and model afterwards.
        """
        #core.HAVE_OPENCL = False
        if self._model is None:
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        calculator = self._model.make_kernel(q_vectors)
        try:
            parameters = self._model_info.parameters
            pairs = [self._get_weights(p) for p in parameters.call_parameters]
            #weights.plot_weights(self._model_info, pairs)
            call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
            return calculator(call_details, values, cutoff=self.cutoff,
                              magnetic=is_magnetic)
        finally:
            # Release even when the evaluation raises, so a failed
            # calculation does not leak kernel resources.
            calculator.release()
            self._model.release()

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion
        :raises ValueError: if *parameter* is not a model parameter
        """
        if parameter not in self.params:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % (parameter,))
        # TODO: Store the disperser object directly in the model.
        # The current method of relying on the sasview GUI to
        # remember them is kind of funky.
        # Note: can't seem to get disperser parameters from sasview
        # (1) Could create a sasview model that has not yet been
        # converted, assign the disperser to one of its polydisperse
        # parameters, then retrieve the disperser parameters from the
        # sasview model.
        # (2) Could write a disperser parameter retriever in sasview.
        # (3) Could modify sasview to use sasmodels.weights dispersers.
        # For now, rely on the fact that the sasview only ever uses
        # new dispersers in the set_dispersion call and create a new
        # one instead of trying to assign parameters.
        self.dispersion[parameter] = dispersion.get_pars()

    def _dispersion_mesh(self):
        # type: () -> Tuple[List[np.ndarray], np.ndarray]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
        """
        Return dispersion weights for parameter
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return [self.multiplicity], [1.0]
            else:
                # For hidden parameters use the default value.
                value = self._model_info.parameters.defaults.get(par.name, np.nan)
                return [value], [1.0]
        elif par.polydisperse:
            dis = self.dispersion[par.name]
            if dis['type'] == 'array':
                value, weight = dis['values'], dis['weights']
            else:
                value, weight = weights.get_weights(
                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                    self.params[par.name], par.limits, par.relative_pd)
            return value, weight / np.sum(weight)
        else:
            return [self.params[par.name]], [1.0]
735
def test_cylinder():
    # type: () -> float
    """
    Check that the cylinder model evaluates at (qx, qy) = (0.1, 0.1) and
    return the result.
    """
    cylinder_class = _make_standard_model('cylinder')
    return cylinder_class().evalDistribution([0.1, 0.1])
744
def test_structure_factor():
    # type: () -> None
    """
    Check that the 2-D hardsphere model evaluates without producing NaN.
    """
    hardsphere = _make_standard_model('hardsphere')()
    result = hardsphere.evalDistribution([0.1, 0.1])
    if np.isnan(result):
        raise ValueError("hardsphere returns null")
755
def test_rpa():
    # type: () -> float
    """
    Check that the 2-D RPA model runs with multiplicity 3 and return the
    result at (qx, qy) = (0.1, 0.1).
    """
    rpa_model = _make_standard_model('rpa')(3)
    return rpa_model.evalDistribution([0.1, 0.1])
764
def test_empty_distribution():
    # type: () -> None
    """
    Check that the result is NaN when the polydispersity distribution is
    empty (here, a cylinder with an out-of-range negative radius).
    """
    model = _make_standard_model('cylinder')()
    for par_name, par_value in (('radius', -1.0), ('background', 0.)):
        model.setParam(par_name, par_value)
    result = model.evalDistribution(np.asarray([0.1]))
    assert np.isnan(result[0]), "empty distribution fails"
776
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models.

    Any failure is annotated with the name of the model being loaded and
    then re-raised.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            # Narrowed from a bare except: KeyboardInterrupt/SystemExit
            # should propagate without being annotated as a load failure.
            annotate_exception("when loading "+name)
            raise
789
def test_old_name():
    # type: () -> None
    """
    Check that the cylinder model can be loaded as
    sas.models.CylinderModel and evaluated.

    Skipped (returns immediately) when old-style plugin support is off or
    sasview is not importable.
    """
    if not SUPPORT_OLD_STYLE_PLUGINS:
        return
    try:
        import sas  # only checking that sasview is on the path
    except ImportError:
        # sasview is not available, so there is nothing to test
        return
    load_standard_models()
    from sas.models.CylinderModel import CylinderModel
    old_style_model = CylinderModel()
    old_style_model.evalDistribution([0.1, 0.1])
805
if __name__ == "__main__":
    # Smoke test when run as a script: evaluate the cylinder at one 2-D point.
    cylinder_value = test_cylinder()
    print("cylinder(0.1,0.1)=%g" % cylinder_value)
Note: See TracBrowser for help on using the repository browser.