source: sasmodels/sasmodels/sasview_model.py @ 946c8d27

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 946c8d27 was 946c8d27, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

add a TODO to the calc_composition_models code

  • Property mode set to 100644
File size: 29.5 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17from os.path import basename, splitext, abspath, getmtime
try:
    import _thread as thread  # Python 3 name of the low-level module
except ImportError:
    import thread  # Python 2
import threading
19
20import numpy as np  # type: ignore
21
22from . import core
23from . import custom
24from . import product
25from . import generate
26from . import weights
27from . import modelinfo
28from .details import make_kernel_args, dispersion_mesh
29
30try:
31    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable
32    from .modelinfo import ModelInfo, Parameter
33    from .kernel import KernelModel
34    MultiplicityInfoType = NamedTuple(
35        'MuliplicityInfo',
36        [("number", int), ("control", str), ("choices", List[str]),
37         ("x_axis_label", str)])
38    SasviewModelType = Callable[[int], "SasviewModel"]
39except ImportError:
40    pass
41
logger = logging.getLogger(__name__)

#: Lock serializing kernel evaluations in SasviewModel.calculate_Iq and
#: SasviewModel.calc_composition_models.  Uses the public threading API,
#: which exists on both Python 2 and Python 3 (the low-level "thread"
#: module was renamed "_thread" in Python 3).
calculation_lock = threading.Lock()

#: True if pre-existing plugins, with the old names and parameters, should
#: continue to be supported.
SUPPORT_OLD_STYLE_PLUGINS = True

# TODO: separate x_axis_label from multiplicity info
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo',
    ["number", "control", "choices", "x_axis_label"],
)

#: set of defined models (standard and custom)
MODELS = {}  # type: Dict[str, SasviewModelType]
#: custom model {path: model} mapping so we can check timestamps
MODEL_BY_PATH = {}  # type: Dict[str, SasviewModelType]
60
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Look up a model class by name.

    Names ending in ".py" are treated as paths to custom models and loaded
    with load_custom_model; any other name must already be registered in
    MODELS.

    :raises ValueError: if a non-".py" name is not a registered model
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
75
76
# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Build every predefined model, register it in MODELS and return the list
    of all registered models.

    A model that fails to build is skipped after its traceback is logged.
    When SUPPORT_OLD_STYLE_PLUGINS is set, the models are also exposed under
    their old sasview names.
    """
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
        except Exception:
            logger.error(traceback.format_exc())
            continue
        MODELS[model_name] = model_class

    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()

    return list(MODELS.values())
95
96
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load a custom model given the model path.

    The model class is cached by *path* in MODEL_BY_PATH and reused while
    the file modification time of *path* is unchanged.  Newly loaded models
    are registered in MODELS; if the name is already taken by a model loaded
    from a different file, the model is renamed to its id, or to
    id + "_user" if that is taken too.

    :param path: file system path of the custom model module
    :return: the model class
    """
    model = MODEL_BY_PATH.get(path, None)
    if model is not None and model.timestamp == getmtime(path):
        #logger.info("Model already loaded %s", path)
        return model

    #logger.info("Loading model %s", path)
    kernel_module = custom.load_custom_kernel_module(path)
    if hasattr(kernel_module, 'Model'):
        # Old style plugin: the module provides the model class directly.
        model = kernel_module.Model
        # Old style models do not set the name in the class attributes, so
        # set it here; this name will be overridden when the object is created
        # with an instance variable that has the same value.
        if model.name == "":
            model.name = splitext(basename(path))[0]
        if not hasattr(model, 'filename'):
            filename = abspath(kernel_module.__file__)
            # Report the source file rather than the compiled bytecode.  Only
            # the trailing extension is touched: a blanket
            # replace('.pyc', '.py') would also corrupt directory names that
            # happen to contain ".pyc".
            if filename.endswith(('.pyc', '.pyo')):
                filename = filename[:-1]
            model.filename = filename
        if not hasattr(model, 'id'):
            model.id = splitext(basename(model.filename))[0]
    else:
        model_info = modelinfo.make_model_info(kernel_module)
        model = make_model_from_info(model_info)
    model.timestamp = getmtime(path)

    # If a model name already exists and we are loading a different model,
    # use the model file name as the model name.
    if model.name in MODELS and not model.filename == MODELS[model.name].filename:
        previous_name = model.name
        model.name = model.id

        # If the new model name is still in the model list (for instance,
        # if we put a cylinder.py in our plug-in directory), then append
        # an identifier.
        if model.name in MODELS and not model.filename == MODELS[model.name].filename:
            model.name = model.id + '_user'
        logger.info("Model %s already exists: using %s [%s]",
                    previous_name, model.name, model.filename)

    MODELS[model.name] = model
    MODEL_BY_PATH[path] = model
    return model
142
143
def make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Create a SasviewModel subclass wrapping *model_info*.

    The class is named after the model; its attributes come from
    _generate_model_attributes, plus the source filename and a constructor
    that forwards *multiplicity* to SasviewModel.__init__.
    """
    attrs = _generate_model_attributes(model_info)

    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    attrs.update(__init__=__init__, filename=model_info.filename)
    return type(model_info.name, (SasviewModel,), attrs)
156
157
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build the sasview model class for the model called *name*.

    *name* can be a standard model name or a path to a custom model; it is
    resolved by generate.load_kernel_module.

    Returns a class that can be used directly as a sasview model.
    """
    return make_model_from_info(
        modelinfo.make_model_info(generate.load_kernel_module(name)))
170
171
def _register_old_models():
    # type: () -> None
    """
    Place the new models into sasview under the old names.

    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
    is available to the plugin modules.

    For each entry in the (3, 1, 2) section of CONVERSION_TABLE, a stand-in
    module ``sas.models.<OldName>`` is registered in sys.modules whose
    ``<OldName>`` attribute is the new model class returned by find_model.
    This is what lets old code run
    ``from sas.models.CylinderModel import CylinderModel``.

    Requires sasview (the ``sas`` package) to be importable; the ``import
    sas`` below raises ImportError otherwise.  find_model raises ValueError
    for a name not registered in MODELS, which aborts the remaining
    registrations.
    """
    import sys
    import sas   # needed in order to set sas.models
    import sas.sascalc.fit
    # Alias the package both in sys.modules (for import statements) and as
    # an attribute of the sas package (for attribute access).
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit

    # Must come after the aliasing above: the import finds the
    # sys.modules['sas.models'] entry set there.
    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE
    for new_name, conversion in CONVERSION_TABLE.get((3, 1, 2), {}).items():
        # CoreShellEllipsoidModel => core_shell_ellipsoid:1
        # Drop any ":<n>" suffix to get the sasmodels model name.
        new_name = new_name.split(':')[0]
        # The old name is conversion[0], unless the entry has a third item,
        # in which case conversion[2] is used instead.
        old_name = conversion[0] if len(conversion) < 3 else conversion[2]
        module_attrs = {old_name: find_model(new_name)}
        # A plain class stands in for the module; "from X import Y" only
        # needs attribute access on whatever sys.modules[X] holds.
        ConstructedModule = type(old_name, (), module_attrs)
        old_path = 'sas.models.' + old_name
        # NOTE(review): this sets an attribute literally named
        # "sas.models.<OldName>" (dots included) on sas.models, so it is not
        # reachable as sas.models.<OldName>; the sys.modules entry below is
        # what makes the import work.  Possibly old_name was intended here.
        setattr(sas.models, old_path, ConstructedModule)
        sys.modules[old_path] = ConstructedModule
197
198
def MultiplicationModel(form_factor, structure_factor):
    # type: ("SasviewModel", "SasviewModel") -> "SasviewModel"
    """
    Combine a form factor model and a structure factor model into a
    product model.

    :param form_factor: sasview model instance supplying the form factor
    :param structure_factor: sasview model instance supplying the structure
        factor
    :return: a new instance of the product model class
    """
    product_info = product.make_product_info(
        form_factor._model_info, structure_factor._model_info)
    return make_model_from_info(product_info)()
205
206
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Build the dictionary of class attributes for a sasview model class.

    The attributes describe the model (names, category, parameter groups,
    multiplicity, drop-down choices) so that it can be queried without
    being instantiated.

    All the attributes should be immutable to avoid accidents, so the
    parameter groups are returned as tuples.
    """
    # TODO: allow model to override axis labels input/output name/unit

    # Multiplicity: the kernel parameter named by model_info.control (if
    # any) selects the model variant and is never fitted.
    xlabel = "" if model_info.profile is None else model_info.profile_axes[0]
    control_par = next(
        (par for par in model_info.parameters.kernel_parameters
         if par.name == model_info.control),
        None)
    non_fittable = []  # type: List[str]
    if control_par is None:
        variants = MultiplicityInfo(0, "", [], xlabel)
    else:
        non_fittable.append(control_par.name)
        if control_par.choices:
            count = len(control_par.choices)
        else:
            count = int(control_par.limits[1])
        variants = MultiplicityInfo(count, control_par.name,
                                    control_par.choices, xlabel)

    # Only a single drop-down list parameter is supported: the first kernel
    # parameter that has choices provides fun_list.
    choice_par = next(
        (par for par in model_info.parameters.kernel_parameters
         if par.choices),
        None)
    if choice_par is None:
        fun_list = []
    else:
        fun_list = choice_par.choices
        if choice_par.length > 1:
            non_fittable.extend(choice_par.id + str(k)
                                for k in range(1, choice_par.length + 1))

    # Group the 2-D user parameters; the ".width" name of every orientation
    # or magnetic parameter is recorded in "fixed".
    orientation_params = []  # type: List[str]
    magnetic_params = []  # type: List[str]
    fixed = []  # type: List[str]
    for par in model_info.parameters.user_parameters({}, is2d=True):
        width_name = par.name + ".width"
        if par.type == 'orientation':
            orientation_params += [par.name, width_name]
        elif par.type == 'magnetic':
            orientation_params.append(par.name)
            magnetic_params.append(par.name)
        else:
            continue
        fixed.append(width_name)

    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }
275
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete model classes are built by make_model_from_info, which creates
    a subclass of SasviewModel whose class attributes are produced by
    _generate_model_attributes.
    """
    # The model-specific class attributes below are overridden in those
    # generated subclasses.  They are included here for typing and
    # documentation purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: List[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: List[str]
    #: names of the fittable parameters
    fixed = None              # type: List[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Initialize parameter values, units/limits and dispersion settings
        from the model definition.

        :param multiplicity: model variant for multiplicity models.  The
            parameters hidden for that variant, and the control parameter
            itself, are left out of *params*.
        """
        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            # Copy so the in-place updates below can never modify a set
            # owned by the model info.
            hidden = set(self._model_info.get_hidden_parameters(multiplicity))
            hidden |= set([self.multiplicity_info.control])
        else:
            hidden = set()
        if self._model_info.structure_factor:
            hidden.add('scale')
            hidden.add('background')
            self._model_info.parameters.defaults['background'] = 0.

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters({}, is2d=True):
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __get_state__(self):
        # type: () -> Dict[str, Any]
        """
        Return a copy of the instance dictionary without the kernel model.

        The kernel model is rebuilt on demand by _calculate_Iq, so it is
        never carried along.  It may be absent from the instance dictionary
        (only the class default None), hence the tolerant pop.
        """
        state = self.__dict__.copy()
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __set_state__(self, state):
        # type: (Dict[str, Any]) -> None
        """Restore the instance state; the kernel model is rebuilt on use."""
        self.__dict__ = state
        self._model = None

    # pickle and copy.deepcopy (used by clone) look for __getstate__ and
    # __setstate__, not the names above; alias them so the kernel model is
    # dropped when the object is copied or pickled.
    __getstate__ = __get_state__
    __setstate__ = __set_state__

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check
        """
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()

    def getProfile(self):
        # type: () -> (np.ndarray, np.ndarray)
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values (the profile
                function output scaled by 1e-6)
        """
        args = {} # type: Dict[str, Any]
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == self.multiplicity_info.control:
                value = float(self.multiplicity)
            elif p.length == 1:
                value = self.params.get(p.id, np.nan)
            else:
                value = np.array([self.params.get(p.id+str(k), np.nan)
                                  for k in range(1, p.length+1)])
            args[p.id] = value

        x, y = self._model_info.profile(**args)
        return x, 1e-6*y

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        :param name: name of the parameter; "<par>.<attr>" addresses a
            dispersion attribute such as "radius.width"
        :param value: value of the parameter
        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            dispersion = self.dispersion.get(toks[0])
            if dispersion is not None and toks[1] in dispersion:
                dispersion[toks[1]] = value
                return
        elif name in self.params:
            # Look for standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        :param name: name of the parameter; "<par>.<attr>" addresses a
            dispersion attribute such as "radius.width"
        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            dispersion = self.dispersion.get(toks[0])
            if dispersion is not None and toks[1] in dispersion:
                return dispersion[toks[1]]
        elif name in self.params:
            # Look for standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model
        """
        # TODO: fix test so that parameter order doesn't matter
        return ['%s.%s' % (p_name, ext)
                for p_name in self.dispersion.keys()
                for ext in ('npts', 'nsigmas', 'width')]

    def clone(self):
        # type: () -> "SasviewModel"
        """ Return a identical copy of self """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]

    def runXY(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info.parameters.has_2d:
                # Convert first so plain lists work with ** and +.
                qx, qy = np.asarray(qx), np.asarray(qy)
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calc_composition_models(self, qx):
        """
        returns parts of the composition model or None if not a composition
        model.

        This recomputes I(q) for *qx* under calculation_lock and returns the
        kernel's intermediate results from that evaluation.
        """
        # TODO: have calculate_Iq return the intermediates.
        #
        # The current interface causes calculate_Iq() to be called twice,
        # once to get the combined result and again to get the intermediate
        # results.  This is necessary for now.
        # Long term, the solution is to change the interface to calculate_Iq
        # so that it returns a results object containing all the bits:
        #     the A, B, C, ... of the composition model
        #     the P and S of the product model,
        #     the combined model before resolution smearing,
        #     the sasmodel before sesans conversion,
        #     the oriented 2D model used to fit oriented usans data,
        #     the final I(q),
        #     ...
        # Have the model calculator add all of these blindly to the data
        # tree, and update the graphs which contain them.  The fitter
        # needs to be updated to use the I(q) value only, ignoring the rest.
        #
        # The simple fix of returning the existing intermediate results
        # will not work for a couple of reasons: (1) another thread may
        # sneak in to compute its own results before calc_composition_models
        # is called, and (2) calculate_Iq is currently called three times:
        # once with q, once with q values before qmin and once with q values
        # after q max.  Both of these should be addressed before
        # replacing this code.
        with calculation_lock:
            self._calculate_Iq(qx)
            return self._intermediate_results

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        ## uncomment the following when trying to debug the uncoordinated calls
        ## to calculate_Iq
        #if calculation_lock.locked():
        #    logger.info("calculation waiting for another thread to complete")
        #    logger.info("\n".join(traceback.format_stack()))

        with calculation_lock:
            return self._calculate_Iq(qx, qy)

    def _calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Evaluate the kernel at *qx* (and *qy* for 2D).

        Called with calculation_lock held by calculate_Iq and
        calc_composition_models.  Stores the kernel's intermediate results
        (or None) in self._intermediate_results.  The kernel and the model
        are released even if the evaluation raises.
        """
        #core.HAVE_OPENCL = False
        if self._model is None:
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        try:
            calculator = self._model.make_kernel(q_vectors)
            try:
                parameters = self._model_info.parameters
                pairs = [self._get_weights(p)
                         for p in parameters.call_parameters]
                #weights.plot_weights(self._model_info, pairs)
                call_details, values, is_magnetic = make_kernel_args(
                    calculator, pairs)
                result = calculator(call_details, values, cutoff=self.cutoff,
                                    magnetic=is_magnetic)
                self._intermediate_results = getattr(calculator, 'results',
                                                     None)
            finally:
                calculator.release()
        finally:
            self._model.release()
        return result

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius, averaged over the
            volume-parameter dispersion weights (1.0 if the model has no ER)
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio, averaged over the
            volume-parameter dispersion weights (1.0 if the model has no VR)
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion
        :raises ValueError: if *parameter* is not a model parameter
        """
        if parameter in self.params:
            # TODO: Store the disperser object directly in the model.
            # The current method of relying on the sasview GUI to
            # remember them is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.
            # (2) Could write a disperser parameter retriever in sasview.
            # (3) Could modify sasview to use sasmodels.weights dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

    def _dispersion_mesh(self):
        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
        """
        Return dispersion weights for parameter

        Hidden parameters (not in *params*) get a single point: the
        multiplicity for the control parameter, otherwise the default value
        (NaN if there is none).  Polydisperse parameters get normalized
        weights; all others get a single point at the current value.
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return [self.multiplicity], [1.0]
            else:
                # For hidden parameters use the default value.
                value = self._model_info.parameters.defaults.get(par.name, np.nan)
                return [value], [1.0]
        elif par.polydisperse:
            dis = self.dispersion[par.name]
            if dis['type'] == 'array':
                value, weight = dis['values'], dis['weights']
            else:
                value, weight = weights.get_weights(
                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                    self.params[par.name], par.limits, par.relative_pd)
            return value, weight / np.sum(weight)
        else:
            return [self.params[par.name]], [1.0]
769
def test_cylinder():
    # type: () -> float
    """
    Smoke test: build the cylinder model, evaluate it at qx = qy = 0.1 and
    return the result.
    """
    cylinder = _make_standard_model('cylinder')()
    return cylinder.evalDistribution([0.1, 0.1])
778
def test_structure_factor():
    # type: () -> float
    """
    Check that the hardsphere structure factor evaluates at qx = qy = 0.1
    without producing NaN.
    """
    hardsphere = _make_standard_model('hardsphere')()
    if np.isnan(hardsphere.evalDistribution([0.1, 0.1])):
        raise ValueError("hardsphere returns null")
789
def test_rpa():
    # type: () -> float
    """
    Smoke test: evaluate the RPA model with multiplicity 3 at qx = qy = 0.1
    and return the result.
    """
    rpa_model = _make_standard_model('rpa')(3)
    return rpa_model.evalDistribution([0.1, 0.1])
798
def test_empty_distribution():
    # type: () -> None
    """
    Check that a cylinder with a negative radius (so no valid
    polydispersity points) evaluates to NaN.
    """
    cylinder = _make_standard_model('cylinder')()
    for par_name, par_value in (('radius', -1.0), ('background', 0.)):
        cylinder.setParam(par_name, par_value)
    iq = cylinder.evalDistribution(np.asarray([0.1]))
    assert np.isnan(iq[0]), "empty distribution fails"
810
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models

    A failure is re-raised after the model name is added to the error
    message.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            # Narrower than a bare except: interrupts such as
            # KeyboardInterrupt pass through untouched.
            annotate_exception("when loading "+name)
            raise
823
def test_old_name():
    # type: () -> None
    """
    Check that the cylinder model can be imported and run under its old
    sasview name, sas.models.CylinderModel.

    Does nothing when old-style plugins are disabled or sasview is not
    importable.
    """
    if not SUPPORT_OLD_STYLE_PLUGINS:
        return
    try:
        import sas  # only used to detect that sasview is installed
    except ImportError:
        return  # sasview is not on the path; nothing to test
    load_standard_models()
    from sas.models.CylinderModel import CylinderModel as OldCylinder
    OldCylinder().evalDistribution([0.1, 0.1])
839
if __name__ == "__main__":
    # Quick smoke test when the module is run as a script.
    _cylinder_value = test_cylinder()
    print("cylinder(0.1,0.1)=%g" % _cylinder_value)
Note: See TracBrowser for help on using the repository browser.