source: sasmodels/sasmodels/sasview_model.py @ 17db833

core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 17db833 was 17db833, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

fix structure factor models, so scale=1, background=0 again

  • Property mode set to 100644
File size: 30.7 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17from os.path import basename, splitext, abspath, getmtime
18try:
19    import _thread as thread
20except ImportError:
21    import thread
22
23import numpy as np  # type: ignore
24
25from . import core
26from . import custom
27from . import product
28from . import generate
29from . import weights
30from . import modelinfo
31from .details import make_kernel_args, dispersion_mesh
32
33try:
34    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable
35    from .modelinfo import ModelInfo, Parameter
36    from .kernel import KernelModel
37    MultiplicityInfoType = NamedTuple(
38        'MultiplicityInfo',
39        [("number", int), ("control", str), ("choices", List[str]),
40         ("x_axis_label", str)])
41    SasviewModelType = Callable[[int], "SasviewModel"]
42except ImportError:
43    pass
44
logger = logging.getLogger(__name__)

#: Process-wide lock; calculate_Iq and calc_composition_models hold it while
#: evaluating a kernel so that calculations do not interleave.
calculation_lock = thread.allocate_lock()

#: True if pre-existing plugins, with the old names and parameters, should
#: continue to be supported.
SUPPORT_OLD_STYLE_PLUGINS = True

# TODO: separate x_axis_label from multiplicity info
#: Multiplicity description: variant count, controlling parameter name,
#: choice labels and profile x-axis label.
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo', "number control choices x_axis_label")

#: set of defined models (standard and custom), keyed by model name
MODELS = {}  # type: Dict[str, SasviewModelType]
#: custom model {path: model} mapping so we can check timestamps
MODEL_BY_PATH = {}  # type: Dict[str, SasviewModelType]
63
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Find a model by name.

    A name ending in ``.py`` is treated as a path and loaded with
    load_custom_model; any other name must already be registered in MODELS.

    :raises ValueError: if the name is not a registered model.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r" % modelname)
    return MODELS[modelname]
78
79
# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Build and register every model listed by core.list_models().

    A model that fails to build is skipped after its traceback is logged at
    error level.  When SUPPORT_OLD_STYLE_PLUGINS is set the models are also
    registered under their old sasview names.

    Returns every model currently in MODELS, which also includes any custom
    models loaded earlier.
    """
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
        except Exception:
            logger.error(traceback.format_exc())
        else:
            MODELS[model_name] = model_class

    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()

    return list(MODELS.values())
98
99
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load a custom model given the model path.

    Loaded models are cached by *path* in MODEL_BY_PATH and reloaded only
    when the file modification time changes.  The model is registered in
    MODELS; if its name is already used by a model from a different file it
    is renamed to its id, or to ``<id>_user`` if that also clashes.

    Old-style plugins (modules defining a ``Model`` class) are used as-is,
    with missing name/filename/id attributes derived from the file name.
    Other modules are converted with modelinfo.make_model_info.
    """
    # Read the timestamp once, before loading, so that an edit made while
    # the module is being loaded still triggers a reload on the next call.
    timestamp = getmtime(path)
    model = MODEL_BY_PATH.get(path, None)
    if model is not None and model.timestamp == timestamp:
        #logger.info("Model already loaded %s", path)
        return model

    #logger.info("Loading model %s", path)
    kernel_module = custom.load_custom_kernel_module(path)
    if hasattr(kernel_module, 'Model'):
        model = kernel_module.Model
        # Old style models do not set the name in the class attributes, so
        # set it here; this name will be overridden when the object is created
        # with an instance variable that has the same value.
        if model.name == "":
            model.name = splitext(basename(path))[0]
        if not hasattr(model, 'filename'):
            # Report the source file rather than the compiled .pyc.  Only the
            # extension is changed; '.pyc' elsewhere in the path is kept.
            filename = abspath(kernel_module.__file__)
            if filename.endswith('.pyc'):
                filename = filename[:-1]
            model.filename = filename
        if not hasattr(model, 'id'):
            model.id = splitext(basename(model.filename))[0]
    else:
        model_info = modelinfo.make_model_info(kernel_module)
        model = make_model_from_info(model_info)
    model.timestamp = timestamp

    # If a model name already exists and we are loading a different model,
    # use the model file name as the model name.
    if model.name in MODELS and not model.filename == MODELS[model.name].filename:
        _previous_name = model.name
        model.name = model.id

        # If the new model name is still in the model list (for instance,
        # if we put a cylinder.py in our plug-in directory), then append
        # an identifier.
        if model.name in MODELS and not model.filename == MODELS[model.name].filename:
            model.name = model.id + '_user'
        logger.info("Model %s already exists: using %s [%s]",
                    _previous_name, model.name, model.filename)

    MODELS[model.name] = model
    MODEL_BY_PATH[path] = model
    return model
145
146
def make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Convert *model_info* into a SasView model wrapper.

    Returns a new subclass of SasviewModel named after the model.  Its class
    attributes come from _generate_model_attributes, plus the model filename
    and an ``__init__`` that forwards *multiplicity* to SasviewModel.
    """
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs = dict(_generate_model_attributes(model_info),
                       __init__=__init__,
                       filename=model_info.filename)
    return type(model_info.name, (SasviewModel,), class_attrs)
159
160
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build the sasview model class for *name*.

    *name* can be a standard model name or a path to a custom model; it is
    resolved by generate.load_kernel_module.

    Returns a class that can be used directly as a sasview model.
    """
    info = modelinfo.make_model_info(generate.load_kernel_module(name))
    return make_model_from_info(info)
173
174
def _register_old_models():
    # type: () -> None
    """
    Place the new models into sasview under the old names.

    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
    is available to the plugin modules.

    For each entry in CONVERSION_TABLE[(3, 1, 2)], a stand-in "module" (a
    plain class named after the old model) is registered in sys.modules as
    ``sas.models.<OldName>``.  The new model class is stored as its
    ``<OldName>`` attribute, so ``from sas.models.CylinderModel import
    CylinderModel`` resolves to the new model.
    """
    import sys
    import sas   # needed in order to set sas.models
    import sas.sascalc.fit
    # Alias sas.sascalc.fit as sas.models, both in the import cache and as an
    # attribute of the sas package, before importing it under that name.
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit

    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE
    # presumably (3, 1, 2) selects the sasview 3.1.2 naming; confirm against
    # conversion_table.
    for new_name, conversion in CONVERSION_TABLE.get((3, 1, 2), {}).items():
        # CoreShellEllipsoidModel => core_shell_ellipsoid:1
        # Drop any ":<variant>" suffix to get the registered model name.
        new_name = new_name.split(':')[0]
        # The old name is conversion[0] unless a third entry overrides it.
        old_name = conversion[0] if len(conversion) < 3 else conversion[2]
        module_attrs = {old_name: find_model(new_name)}
        # A plain class stands in for the module object.
        ConstructedModule = type(old_name, (), module_attrs)
        old_path = 'sas.models.' + old_name
        # NOTE(review): the attribute name set here contains dots
        # ('sas.models.<OldName>'), so it is not reachable as
        # sas.models.<OldName>; imports succeed only through the sys.modules
        # entry below.  Confirm whether setattr(sas.models, old_name, ...)
        # was intended.
        setattr(sas.models, old_path, ConstructedModule)
        sys.modules[old_path] = ConstructedModule
200
201
def MultiplicationModel(form_factor, structure_factor):
    # type: ("SasviewModel", "SasviewModel") -> "SasviewModel"
    """
    Return an instance of the product model P(q)*S(q).

    *form_factor* and *structure_factor* are SasviewModel instances.  Their
    model info is combined with product.make_product_info, and the new model
    is instantiated with the multiplicity of *form_factor*.
    """
    product_info = product.make_product_info(
        form_factor._model_info, structure_factor._model_info)
    product_class = make_model_from_info(product_info)
    return product_class(form_factor.multiplicity)
208
209
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Build the class attribute dictionary for a sasview model class.

    The attributes describe the model well enough that it can be queried
    without being instantiated.  Sequences are stored as tuples so that the
    class attributes stay immutable.
    """
    # TODO: allow model to override axis labels input/output name/unit
    if model_info.profile is not None:
        xlabel = model_info.profile_axes[0]
    else:
        xlabel = ""

    # Multiplicity: the first kernel parameter named by model_info.control is
    # not fittable; its choices (or its upper limit) give the variant count.
    non_fittable = []  # type: List[str]
    variants = MultiplicityInfo(0, "", [], xlabel)
    control = next((p for p in model_info.parameters.kernel_parameters
                    if p.name == model_info.control), None)
    if control is not None:
        non_fittable.append(control.name)
        if control.choices:
            count = len(control.choices)
        else:
            count = int(control.limits[1])
        variants = MultiplicityInfo(count, control.name, control.choices, xlabel)

    # Only a single drop-down list parameter is supported: the first kernel
    # parameter with choices supplies fun_list.  For a vector parameter the
    # individual elements (<id>1 .. <id>N) are marked non-fittable.
    fun_list = []
    selector = next((p for p in model_info.parameters.kernel_parameters
                     if p.choices), None)
    if selector is not None:
        fun_list = selector.choices
        if selector.length > 1:
            non_fittable.extend(selector.id + str(k)
                                for k in range(1, selector.length + 1))

    # Organize parameter sets.  Magnetic parameters are listed among the
    # orientation parameters as well as in their own list.
    orientation_params = []
    magnetic_params = []
    fixed = []
    for p in model_info.parameters.user_parameters({}, is2d=True):
        width_name = p.name + ".width"
        if p.type == 'orientation':
            orientation_params.extend((p.name, width_name))
            fixed.append(width_name)
        elif p.type == 'magnetic':
            orientation_params.append(p.name)
            magnetic_params.append(p.name)
            fixed.append(width_name)

    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }
278
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    make_model_from_info creates one subclass per model and fills in the
    model-specific class attributes.  Instances hold the current parameter
    values and dispersion settings.
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: List[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: List[str]
    #: names of the fittable parameters
    fixed = None              # type: List[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]
    #: ``results`` attribute of the kernel from the most recent calculation
    #: (None if the kernel has none); returned by calc_composition_models.
    _intermediate_results = None  # type: Any

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Initialize parameter values, details and dispersion settings.

        :param multiplicity: number of variants for multiplicity models, or
            None.  Parameters hidden for this multiplicity, and the
            multiplicity control parameter itself, are not exposed in
            *params*.
        """
        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            # Copy, so that the collection returned by the model info is never
            # modified in place (it may be shared between instances).
            hidden = set(self._model_info.get_hidden_parameters(multiplicity))
            hidden.add(self.multiplicity_info.control)
        else:
            hidden = set()
        if self._model_info.structure_factor:
            # Hide scale/background.  _get_weights uses the parameter
            # defaults for hidden parameters, so structure factors are
            # evaluated with the default scale and background=0.
            hidden.add('scale')
            hidden.add('background')
            self._model_info.parameters.defaults['background'] = 0.

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters({}, is2d=True):
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                # Relative widths are limited to [0, 1]; absolute widths
                # only need to be non-negative.
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __getstate__(self):
        # type: () -> Dict[str, Any]
        """
        Return the instance state for pickle/copy.

        The compiled kernel model is left out; it is rebuilt on demand by
        _calculate_Iq.
        """
        state = self.__dict__.copy()
        # The compiled kernel may not be picklable or copyable; drop it.  It
        # is only present once a calculation has run, hence the default.
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __setstate__(self, state):
        # type: (Dict[str, Any]) -> None
        """Restore state from __getstate__; the kernel is rebuilt lazily."""
        self.__dict__ = state
        self._model = None

    # Backward-compatible aliases for the original (misspelled) hook names,
    # which pickle and copy never invoked.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check
        :return: True if *par_name* is listed in the (ill-named) *fixed*
            attribute.
        """
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()


    def getProfile(self):
        # type: () -> Tuple[np.ndarray, np.ndarray]
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        args = {} # type: Dict[str, Any]
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == self.multiplicity_info.control:
                value = float(self.multiplicity)
            elif p.length == 1:
                value = self.params.get(p.id, np.nan)
            else:
                # Vector parameters are stored as <id>1 .. <id>N.
                value = np.array([self.params.get(p.id+str(k), np.nan)
                                  for k in range(1, p.length+1)])
            args[p.id] = value

        x, y = self._model_info.profile(**args)
        # presumably converts SLD from 1e-6/A^2 units to 1/A^2; confirm
        # against the model profile functions.
        return x, 1e-6*y

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        :param name: name of the parameter, or ``<par>.<field>`` for a
            dispersion setting (width, npts, nsigmas, type)
        :param value: value of the parameter

        :raises ValueError: if the parameter does not exist
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            par_name, field = toks
            settings = self.dispersion.get(par_name)
            if settings is not None and field in settings:
                settings[field] = value
                return
        elif name in self.params:
            # Look for standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        :param name: name of the parameter, or ``<par>.<field>`` for a
            dispersion setting (width, npts, nsigmas, type)

        :raises ValueError: if the parameter does not exist
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            par_name, field = toks
            settings = self.dispersion.get(par_name)
            if settings is not None and field in settings:
                return settings[field]
        elif name in self.params:
            # Look for standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model
        """
        # TODO: fix test so that parameter order doesn't matter
        return ['%s.%s' % (p_name, ext)
                for p_name in self.dispersion.keys()
                for ext in ('npts', 'nsigmas', 'width')]

    def clone(self):
        # type: () -> "SasviewModel"
        """
        Return an identical copy of self.

        The copy goes through __getstate__/__setstate__, so the compiled
        kernel is not shared; the copy rebuilds it on first use.
        """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, Tuple[float, float], List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]


    def runXY(self, x=0.0):
        # type: (Union[float, Tuple[float, float], List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info.parameters.has_2d:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calc_composition_models(self, qx):
        """
        returns parts of the composition model or None if not a composition
        model.
        """
        # TODO: have calculate_Iq return the intermediates.
        #
        # The current interface causes calculate_Iq() to be called twice,
        # once to get the combined result and again to get the intermediate
        # results.  This is necessary for now.
        # Long term, the solution is to change the interface to calculate_Iq
        # so that it returns a results object containing all the bits:
        #     the A, B, C, ... of the composition model (and any subcomponents?)
        #     the P and S of the product model,
        #     the combined model before resolution smearing,
        #     the sasmodel before sesans conversion,
        #     the oriented 2D model used to fit oriented usans data,
        #     the final I(q),
        #     ...
        #
        # Have the model calculator add all of these blindly to the data
        # tree, and update the graphs which contain them.  The fitter
        # needs to be updated to use the I(q) value only, ignoring the rest.
        #
        # The simple fix of returning the existing intermediate results
        # will not work for a couple of reasons: (1) another thread may
        # sneak in to compute its own results before calc_composition_models
        # is called, and (2) calculate_Iq is currently called three times:
        # once with q, once with q values before qmin and once with q values
        # after q max.  Both of these should be addressed before
        # replacing this code.
        composition = self._model_info.composition
        if composition and composition[0] == 'product': # only P*S for now
            with calculation_lock:
                self._calculate_Iq(qx)
                return self._intermediate_results
        else:
            return None

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        ## uncomment the following when trying to debug the uncoordinated calls
        ## to calculate_Iq
        #if calculation_lock.locked():
        #    logger.info("calculation waiting for another thread to complete")
        #    logger.info("\n".join(traceback.format_stack()))

        with calculation_lock:
            return self._calculate_Iq(qx, qy)

    def _calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Evaluate the kernel at *qx* (and *qy* for 2D) with current parameters.

        Callers in this class hold calculation_lock while calling this.  The
        kernel model is built on first use, and the kernel and model are
        released even if the evaluation raises.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        try:
            calculator = self._model.make_kernel(q_vectors)
            try:
                parameters = self._model_info.parameters
                pairs = [self._get_weights(p) for p in parameters.call_parameters]
                call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
                result = calculator(call_details, values, cutoff=self.cutoff,
                                    magnetic=is_magnetic)
                # Keep any intermediate results (e.g. P and S of a product
                # model) for calc_composition_models.
                self._intermediate_results = getattr(calculator, 'results', None)
            finally:
                calculator.release()
        finally:
            self._model.release()
        return result

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius (1.0 if the model does
            not define ER)
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            # Weighted average over the volume dispersity mesh.
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio (1.0 if the model does not
            define VR)
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion

        :raises ValueError: if *parameter* is not a model parameter
        """
        if parameter in self.params:
            # TODO: Store the disperser object directly in the model.
            # The current method of relying on the sasview GUI to
            # remember them is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.
            # (2) Could write a disperser parameter retriever in sasview.
            # (3) Could modify sasview to use sasmodels.weights dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

    def _dispersion_mesh(self):
        # type: () -> Tuple[List[np.ndarray], np.ndarray]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[Any, Sequence[float], Sequence[float]]
        """
        Return (value, dispersity points, weights) for parameter *par*.

        Hidden parameters use the multiplicity (for the control parameter)
        or the parameter default.  Non-polydisperse parameters get a single
        point with weight 1.
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return self.multiplicity, [self.multiplicity], [1.0]
            else:
                # For hidden parameters use default values.  This sets
                # scale=1 and background=0 for structure factors
                default = self._model_info.parameters.defaults.get(par.name, np.nan)
                return default, [default], [1.0]
        elif par.polydisperse:
            value = self.params[par.name]
            dis = self.dispersion[par.name]
            if dis['type'] == 'array':
                dispersity, weight = dis['values'], dis['weights']
            else:
                dispersity, weight = weights.get_weights(
                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                    value, par.limits, par.relative_pd)
            return value, dispersity, weight
        else:
            value = self.params[par.name]
            return value, [value], [1.0]
783
def test_cylinder():
    # type: () -> float
    """
    Check that the cylinder model evaluates, returning its value at
    (qx, qy) = (0.1, 0.1).
    """
    cylinder_model = _make_standard_model('cylinder')()
    return cylinder_model.evalDistribution([0.1, 0.1])
792
def test_structure_factor():
    # type: () -> None
    """
    Check that the hardsphere structure factor gives a non-NaN value both in
    2-D at (0.1, 0.1) and in 1-D at the matching |q| = 0.1*sqrt(2).
    """
    hardsphere = _make_standard_model('hardsphere')()
    value2d = hardsphere.evalDistribution([0.1, 0.1])
    q_1d = np.array([0.1*np.sqrt(2)])
    value1d = hardsphere.evalDistribution(q_1d)
    #print("hardsphere", value1d, value2d)
    if np.isnan(value1d) or np.isnan(value2d):
        raise ValueError("hardsphere returns nan")
805
def test_product():
    # type: () -> None
    """
    Test that the 2-D cylinder*hayter_msa product model runs and doesn't
    produce NaN.
    """
    S = _make_standard_model('hayter_msa')()
    P = _make_standard_model('cylinder')()
    model = MultiplicationModel(P, S)
    value = model.evalDistribution([0.1, 0.1])
    if np.isnan(value):
        raise ValueError("cylinder*hayter_msa returns NaN")
817
def test_rpa():
    # type: () -> float
    """
    Check that the 2-D RPA model, built with multiplicity 3, evaluates at
    (qx, qy) = (0.1, 0.1).
    """
    rpa_model = _make_standard_model('rpa')(3)
    return rpa_model.evalDistribution([0.1, 0.1])
826
def test_empty_distribution():
    # type: () -> None
    """
    Check that sasmodels returns NaN when there are no polydispersity points,
    here provoked by a negative cylinder radius.
    """
    cylinder = _make_standard_model('cylinder')()
    for par_name, par_value in (('radius', -1.0), ('background', 0.)):
        cylinder.setParam(par_name, par_value)
    Iq = cylinder.evalDistribution(np.asarray([0.1]))
    assert np.isnan(Iq[0]), "empty distribution fails"
838
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            # presumably annotates the in-flight exception with the model
            # name; the error still propagates.
            annotate_exception("when loading "+name)
            raise
851
def test_old_name():
    # type: () -> None
    """
    Load and run cylinder model from sas.models.CylinderModel

    Skipped silently when old-style plugin support is disabled or when
    sasview (the ``sas`` package) is not importable.
    """
    if SUPPORT_OLD_STYLE_PLUGINS:
        try:
            import sas  # pylint: disable=unused-import
        except ImportError:
            sas = None
        if sas is not None:
            load_standard_models()
            from sas.models.CylinderModel import CylinderModel
            CylinderModel().evalDistribution([0.1, 0.1])
867
if __name__ == "__main__":
    # Quick smoke test.  Other checks that can be enabled by hand:
    #     test_product()
    #     test_structure_factor()
    #     print("rpa:", test_rpa())
    #     test_empty_distribution()
    cylinder_value = test_cylinder()
    print("cylinder(0.1,0.1)=%g" % cylinder_value)
Note: See TracBrowser for help on using the repository browser.