source: sasmodels/sasmodels/sasview_model.py @ 9644b5a

core_shell_microgels costrafo411 magnetic_model ticket-1257-vesicle-product ticket_1156 ticket_1265_superball ticket_822_more_unit_tests
Last change on this file since 9644b5a was 9644b5a, checked in by Paul Kienzle <pkienzle@…>, 4 years ago

avoid recalculation of components unless it is a P@S product model

  • Property mode set to 100644
File size: 29.7 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17from os.path import basename, splitext, abspath, getmtime
18import thread
19
20import numpy as np  # type: ignore
21
22from . import core
23from . import custom
24from . import product
25from . import generate
26from . import weights
27from . import modelinfo
28from .details import make_kernel_args, dispersion_mesh
29
try:
    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable
    from .modelinfo import ModelInfo, Parameter
    from .kernel import KernelModel
    # Typing-only counterpart of the MultiplicityInfo namedtuple defined below.
    MultiplicityInfoType = NamedTuple(
        'MultiplicityInfo',
        [("number", int), ("control", str), ("choices", List[str]),
         ("x_axis_label", str)])
    SasviewModelType = Callable[[int], "SasviewModel"]
except ImportError:
    # In this file the names above only appear in "# type:" comments, so
    # the module still works when typing (or these imports) is unavailable.
    pass

logger = logging.getLogger(__name__)

#: Module-wide lock held by calculate_Iq and calc_composition_models while
#: the kernel is evaluated, so only one model calculation runs at a time.
calculation_lock = thread.allocate_lock()

#: True if pre-existing plugins, with the old names and parameters, should
#: continue to be supported.
SUPPORT_OLD_STYLE_PLUGINS = True

# TODO: separate x_axis_label from multiplicity info
#: (number of variants, name of the control parameter, list of choices,
#: x-axis label for the profile plot)
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo',
    ["number", "control", "choices", "x_axis_label"],
)

#: set of defined models (standard and custom)
MODELS = {}  # type: Dict[str, SasviewModelType]
#: custom model {path: model} mapping so we can check timestamps
MODEL_BY_PATH = {}  # type: Dict[str, SasviewModelType]
60
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Return the model class registered under *modelname*.

    Names ending in ``.py`` are treated as paths and loaded through
    :func:`load_custom_model`; any other name must already be present in
    the MODELS registry, otherwise ValueError is raised.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r" % modelname)
    return MODELS[modelname]
75
76
# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Register every predefined model in MODELS and return the registry contents.

    A model that fails to build has its traceback logged at ERROR level and
    is not registered.  The returned list is all of MODELS, so it also
    contains any models that were registered before this call.
    """
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
        except Exception:
            logger.error(traceback.format_exc())
            continue
        MODELS[model_name] = model_class

    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()

    return list(MODELS.values())
95
96
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load a custom model given the model path.

    The loaded class is cached per *path* in MODEL_BY_PATH and reused while
    the file's modification time is unchanged; otherwise the file is loaded
    again.  A module that defines a ``Model`` class is treated as an
    old-style plugin; any other module is converted with
    ``modelinfo.make_model_info``.  The resulting class is registered in
    MODELS, under a different name if its name is already taken by a model
    that comes from another file.
    """
    model = MODEL_BY_PATH.get(path, None)
    if model is not None and model.timestamp == getmtime(path):
        #logger.info("Model already loaded %s", path)
        return model

    #logger.info("Loading model %s", path)
    kernel_module = custom.load_custom_kernel_module(path)
    if hasattr(kernel_module, 'Model'):
        model = kernel_module.Model
        # Old style models do not set the name in the class attributes, so
        # set it here; this name will be overridden when the object is created
        # with an instance variable that has the same value.
        if model.name == "":
            model.name = splitext(basename(path))[0]
        if not hasattr(model, 'filename'):
            filename = abspath(kernel_module.__file__)
            # Report the source file rather than the compiled one.  Only the
            # suffix is changed: replacing '.pyc' anywhere in the string would
            # also mangle directories such as '~/.pycharm/...'.
            if filename.endswith('.pyc'):
                filename = filename[:-1]
            model.filename = filename
        if not hasattr(model, 'id'):
            model.id = splitext(basename(model.filename))[0]
    else:
        model_info = modelinfo.make_model_info(kernel_module)
        model = make_model_from_info(model_info)
    model.timestamp = getmtime(path)

    # If a model name already exists and we are loading a different model,
    # use the model file name as the model name.
    if model.name in MODELS and model.filename != MODELS[model.name].filename:
        previous_name = model.name
        model.name = model.id

        # If the new model name is still in the model list (for instance,
        # if we put a cylinder.py in our plug-in directory), then append
        # an identifier.
        if model.name in MODELS and model.filename != MODELS[model.name].filename:
            model.name = model.id + '_user'
        logger.info("Model %s already exists: using %s [%s]",
                    previous_name, model.name, model.filename)

    MODELS[model.name] = model
    MODEL_BY_PATH[path] = model
    return model
142
143
def make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Build a new SasviewModel subclass wrapping *model_info*.

    The class is named after ``model_info.name``, gets the class attributes
    computed by :func:`_generate_model_attributes`, records
    ``model_info.filename``, and has an ``__init__`` that forwards the
    optional *multiplicity* to :meth:`SasviewModel.__init__`.
    """
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs = _generate_model_attributes(model_info)
    class_attrs.update(__init__=__init__, filename=model_info.filename)
    model_class = type(model_info.name, (SasviewModel,), class_attrs) # type: SasviewModelType
    return model_class
156
157
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build the sasview model class for the model called *name*.

    *name* is passed to ``generate.load_kernel_module``; the resulting
    module is turned into a ModelInfo and then into a class that can be
    used directly as a sasview model.  (An earlier docstring said a path to
    a custom model is also accepted here; confirm against
    generate.load_kernel_module.)
    """
    return make_model_from_info(
        modelinfo.make_model_info(generate.load_kernel_module(name)))
170
171
def _register_old_models():
    # type: () -> None
    """
    Place the new models into sasview under the old names.

    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
    is available to the plugin modules.  Then, for every entry of the 3.1.2
    conversion table, publish a pseudo-module ``sas.models.<OldName>`` that
    carries the model class as attribute ``<OldName>``.  This keeps
    ``from sas.models.<OldName> import <OldName>`` working.

    A model that is not in MODELS (for instance because it failed to load in
    load_standard_models) is skipped with a warning instead of aborting the
    whole registration.
    """
    import sys
    import sas   # needed in order to set sas.models
    import sas.sascalc.fit
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit

    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE
    for new_name, conversion in CONVERSION_TABLE.get((3, 1, 2), {}).items():
        # CoreShellEllipsoidModel => core_shell_ellipsoid:1
        new_name = new_name.split(':')[0]
        old_name = conversion[0] if len(conversion) < 3 else conversion[2]
        try:
            model = find_model(new_name)
        except ValueError:
            logger.warning("cannot register old name %s: model %s is not loaded",
                           old_name, new_name)
            continue
        ConstructedModule = type(old_name, (), {old_name: model})
        old_path = 'sas.models.' + old_name
        # Attribute on the package, so sas.models.<OldName> resolves after
        # "import sas.models" (a dotted attribute name would be unreachable),
        # and an entry in sys.modules, so the dotted import path works.
        setattr(sas.models, old_name, ConstructedModule)
        sys.modules[old_path] = ConstructedModule
197
198
def MultiplicationModel(form_factor, structure_factor):
    # type: ("SasviewModel", "SasviewModel") -> "SasviewModel"
    """
    Combine a form factor and a structure factor into a product model.

    The ModelInfo objects of the two models are merged with
    ``product.make_product_info``; a new model class is built from the
    result and a fresh instance of it (default multiplicity) is returned.
    """
    combined_info = product.make_product_info(form_factor._model_info,
                                              structure_factor._model_info)
    product_class = make_model_from_info(combined_info)
    return product_class()
205
206
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Generate the class attributes for the model.

    This should include all the information necessary to query the model
    details so that you do not need to instantiate a model to query it.

    All the attributes should be immutable to avoid accidents.
    """
    # TODO: allow model to override axis labels input/output name/unit
    parameters = model_info.parameters

    # Multiplicity: the kernel parameter named by model_info.control sets
    # the number of variants and is never fitted.
    xlabel = "" if model_info.profile is None else model_info.profile_axes[0]
    non_fittable = []  # type: List[str]
    control_par = next((par for par in parameters.kernel_parameters
                        if par.name == model_info.control), None)
    if control_par is None:
        variants = MultiplicityInfo(0, "", [], xlabel)
    else:
        non_fittable.append(control_par.name)
        if control_par.choices:
            count = len(control_par.choices)
        else:
            count = int(control_par.limits[1])
        variants = MultiplicityInfo(count, control_par.name,
                                    control_par.choices, xlabel)

    # Only a single drop-down list parameter is supported: use the first
    # kernel parameter that has choices.
    fun_list = []
    dropdown = next((par for par in parameters.kernel_parameters
                     if par.choices), None)
    if dropdown is not None:
        fun_list = dropdown.choices
        if dropdown.length > 1:
            non_fittable.extend(dropdown.id + str(index)
                                for index in range(1, dropdown.length + 1))

    # Orientation and magnetic parameters.  Magnetic parameters are listed
    # with the orientation parameters as well as on their own.
    orientation_params, magnetic_params, fixed = [], [], []
    for par in parameters.user_parameters({}, is2d=True):
        width_name = par.name + ".width"
        if par.type == 'orientation':
            orientation_params.extend([par.name, width_name])
            fixed.append(width_name)
        elif par.type == 'magnetic':
            orientation_params.append(par.name)
            magnetic_params.append(par.name)
            fixed.append(width_name)

    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }
275
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete model classes are created by make_model_from_info, which
    subclasses SasviewModel and fills in the class attributes declared below
    from the model's ModelInfo.  Instances hold parameter values (*params*),
    polydispersity settings (*dispersion*) and units/limits (*details*), and
    build and evaluate the kernel on demand.
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: List[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: List[str]
    #: names of the fittable parameters (see is_fittable)
    fixed = None              # type: List[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]
    #: ``results`` attribute of the kernel from the most recent evaluation,
    #: or None if it had none; returned by calc_composition_models.
    _intermediate_results = None # type: Any

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Set up parameter values, details and dispersion settings.

        When *multiplicity* is given, the parameters the model info reports
        as hidden for that multiplicity, plus the multiplicity control
        parameter itself, are left out of *params*, *details* and
        *dispersion*.  Structure factor models additionally hide *scale*
        and *background*.
        """
        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the mulitplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            hidden = self._model_info.get_hidden_parameters(multiplicity)
            hidden |= set([self.multiplicity_info.control])
        else:
            hidden = set()
        if self._model_info.structure_factor:
            hidden.add('scale')
            hidden.add('background')
            # Note: this changes the defaults on the shared model info.
            self._model_info.parameters.defaults['background'] = 0.

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters({}, is2d=True):
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __getstate__(self):
        # type: () -> Dict[str, Any]
        """
        Return the instance state without the compiled kernel model.

        Used by pickle and by copy.deepcopy (and hence clone); the kernel
        is rebuilt on the next calculation.
        """
        state = self.__dict__.copy()
        # _model may still be the class-level default (never built).
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __setstate__(self, state):
        # type: (Dict[str, Any]) -> None
        """Restore *state*; the kernel model is rebuilt lazily."""
        self.__dict__ = state
        self._model = None

    # Backward-compatible aliases for the original (misspelled) names,
    # which pickle/copy never called.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check
        :return: True if *par_name* is listed in *fixed*
        """
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()

    def getProfile(self):
        # type: () -> (np.ndarray, np.ndarray)
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values

        The multiplicity control parameter takes its value from
        *multiplicity* when one was given; otherwise (the control parameter
        is then an ordinary entry in *params*) the value in *params* is used.
        Vector parameters are gathered from ``<id>1 .. <id>N``.  Missing
        values are passed as NaN.
        """
        args = {} # type: Dict[str, Any]
        for p in self._model_info.parameters.kernel_parameters:
            if (p.id == self.multiplicity_info.control
                    and self.multiplicity is not None):
                value = float(self.multiplicity)
            elif p.length == 1:
                value = self.params.get(p.id, np.nan)
            else:
                value = np.array([self.params.get(p.id+str(k), np.nan)
                                  for k in range(1, p.length+1)])
            args[p.id] = value

        x, y = self._model_info.profile(**args)
        # The profile's SLD values are scaled by 1e-6 for display.
        return x, 1e-6*y

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        :param name: name of the parameter; ``<par>.<field>`` (exactly one
            dot) addresses a field of the dispersion settings for *par*
        :param value: value of the parameter
        :raises ValueError: if the parameter (or dispersion field) does not
            exist

        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            pd_name, field = toks
            if pd_name in self.dispersion and field in self.dispersion[pd_name]:
                self.dispersion[pd_name][field] = value
                return
        elif name in self.params:
            # Look for standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        :param name: name of the parameter; ``<par>.<field>`` (exactly one
            dot) addresses a field of the dispersion settings for *par*
        :raises ValueError: if the parameter (or dispersion field) does not
            exist

        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            pd_name, field = toks
            if pd_name in self.dispersion and field in self.dispersion[pd_name]:
                return self.dispersion[pd_name][field]
        elif name in self.params:
            # Look for standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model
        """
        # TODO: fix test so that parameter order doesn't matter
        ret = ['%s.%s' % (p_name, ext)
               for p_name in self.dispersion.keys()
               for ext in ('npts', 'nsigmas', 'width')]
        return ret

    def clone(self):
        # type: () -> "SasviewModel"
        """ Return a identical copy of self (without the compiled kernel) """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]


    def runXY(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info.parameters.has_2d:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calc_composition_models(self, qx):
        """
        returns parts of the composition model or None if not a composition
        model.
        """
        # TODO: have calculate_Iq return the intermediates.
        #
        # The current interface causes calculate_Iq() to be called twice,
        # once to get the combined result and again to get the intermediate
        # results.  This is necessary for now.
        # Long term, the solution is to change the interface to calculate_Iq
        # so that it returns a results object containing all the bits:
        #     the A, B, C, ... of the composition model (and any subcomponents?)
        #     the P and S of the product model,
        #     the combined model before resolution smearing,
        #     the sasmodel before sesans conversion,
        #     the oriented 2D model used to fit oriented usans data,
        #     the final I(q),
        #     ...
        #
        # Have the model calculator add all of these blindly to the data
        # tree, and update the graphs which contain them.  The fitter
        # needs to be updated to use the I(q) value only, ignoring the rest.
        #
        # The simple fix of returning the existing intermediate results
        # will not work for a couple of reasons: (1) another thread may
        # sneak in to compute its own results before calc_composition_models
        # is called, and (2) calculate_Iq is currently called three times:
        # once with q, once with q values before qmin and once with q values
        # after q max.  Both of these should be addressed before
        # replacing this code.
        composition = self._model_info.composition
        if composition and composition[0] == 'product': # only P*S for now
            with calculation_lock:
                self._calculate_Iq(qx)
                return self._intermediate_results
        else:
            return None

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        ## uncomment the following when trying to debug the uncoordinated calls
        ## to calculate_Iq
        #if calculation_lock.locked():
        #    logger.info("calculation waiting for another thread to complete")
        #    logger.info("\n".join(traceback.format_stack()))

        with calculation_lock:
            return self._calculate_Iq(qx, qy)

    def _calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Evaluate the kernel for *qx* (and *qy* in 2D) with the current
        parameters.

        Does not take calculation_lock itself; both callers in this class
        hold it.  The kernel's ``results`` attribute, if any, is stored in
        *_intermediate_results*.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        # Release the kernel and the model even when the evaluation raises,
        # so an error does not leak them.
        try:
            calculator = self._model.make_kernel(q_vectors)
            try:
                parameters = self._model_info.parameters
                pairs = [self._get_weights(p) for p in parameters.call_parameters]
                call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
                result = calculator(call_details, values, cutoff=self.cutoff,
                                    magnetic=is_magnetic)
                self._intermediate_results = getattr(calculator, 'results', None)
            finally:
                calculator.release()
        finally:
            self._model.release()
        return result

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        The model's ER function is evaluated on the volume-parameter
        dispersion mesh and averaged with the mesh weights.

        :return: the value of the effective radius (1.0 if the model has
            no ER function)
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        The model's VR function gives (whole, part) on the volume-parameter
        dispersion mesh; the result is the weighted sum of *part* over the
        weighted sum of *whole*.

        :return: the value of the volf ratio (1.0 if the model has no VR
            function)
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion
        :raises ValueError: if *parameter* is not a model parameter
        """
        if parameter in self.params:
            # TODO: Store the disperser object directly in the model.
            # The current method of relying on the sasview GUI to
            # remember them is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.
            # (2) Could write a disperser parameter retriever in sasview.
            # (3) Could modify sasview to use sasmodels.weights dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

    def _dispersion_mesh(self):
        # type: () -> Tuple[List[np.ndarray], np.ndarray]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
        """
        Return dispersion weights for parameter

        Hidden parameters give a single point: the multiplicity for the
        control parameter, otherwise the model default (NaN if none).
        Polydisperse parameters give the dispersion points with weights
        normalized to sum to 1; other parameters give their current value.
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return [self.multiplicity], [1.0]
            else:
                # For hidden parameters use the default value.
                value = self._model_info.parameters.defaults.get(par.name, np.nan)
                return [value], [1.0]
        elif par.polydisperse:
            dis = self.dispersion[par.name]
            if dis['type'] == 'array':
                value, weight = dis['values'], dis['weights']
            else:
                value, weight = weights.get_weights(
                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                    self.params[par.name], par.limits, par.relative_pd)
            return value, weight / np.sum(weight)
        else:
            return [self.params[par.name]], [1.0]
774
def test_cylinder():
    # type: () -> np.ndarray
    """
    Smoke test: build the cylinder model and return its value at
    [qx, qy] = [0.1, 0.1].
    """
    model = _make_standard_model('cylinder')()
    return model.evalDistribution([0.1, 0.1])
783
def test_structure_factor():
    # type: () -> None
    """
    Check that the hardsphere structure factor evaluates in 2-D at
    [0.1, 0.1] without producing NaN.
    """
    hardsphere = _make_standard_model('hardsphere')()
    iq = hardsphere.evalDistribution([0.1, 0.1])
    if np.isnan(iq):
        raise ValueError("hardsphere returns null")
794
def test_rpa():
    # type: () -> np.ndarray
    """
    Smoke test: evaluate the RPA model (multiplicity 3) in 2-D at
    [0.1, 0.1] and return the result.
    """
    rpa_model = _make_standard_model('rpa')(3)
    return rpa_model.evalDistribution([0.1, 0.1])
803
def test_empty_distribution():
    # type: () -> None
    """
    Check that the cylinder model returns NaN when no polydispersity points
    remain.  A radius of -1 is used to trigger that case (presumably because
    it lies outside the radius limits — confirm against weights.get_weights).
    """
    model = _make_standard_model('cylinder')()
    for par_name, par_value in (('radius', -1.0), ('background', 0.)):
        model.setParam(par_name, par_value)
    iq = model.evalDistribution(np.asarray([0.1]))
    assert np.isnan(iq[0]), "empty distribution fails"
815
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models

    On failure, "when loading <name>" is passed to annotate_exception and
    the exception is re-raised.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            # Only ordinary errors are annotated; KeyboardInterrupt and
            # SystemExit propagate untouched.
            annotate_exception("when loading "+name)
            raise
828
def test_old_name():
    # type: () -> None
    """
    Check that the cylinder model can be imported and run via its old-style
    path, sas.models.CylinderModel.

    Does nothing when old-style plugin support is disabled or when the
    ``sas`` package (sasview) cannot be imported.
    """
    if not SUPPORT_OLD_STYLE_PLUGINS:
        return
    try:
        import sas  # only test when sasview is on the path
    except ImportError:
        return
    load_standard_models()
    from sas.models.CylinderModel import CylinderModel as OldCylinder
    legacy_model = OldCylinder()
    legacy_model.evalDistribution([0.1, 0.1])
844
if __name__ == "__main__":
    # Script entry: print the cylinder model value at (qx, qy) = (0.1, 0.1).
    cylinder_value = test_cylinder()
    print("cylinder(0.1,0.1)=%g" % cylinder_value)
    # test_empty_distribution() can be called here as an extra check.
Note: See TracBrowser for help on using the repository browser.