source: sasmodels/sasmodels/sasview_model.py @ e65c3ba

core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since e65c3ba was e65c3ba, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

lint

  • Property mode set to 100644
File size: 30.8 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17from os.path import basename, splitext, abspath, getmtime
18try:
19    import _thread as thread
20except ImportError:
21    import thread
22
23import numpy as np  # type: ignore
24
25from . import core
26from . import custom
27from . import product
28from . import generate
29from . import weights
30from . import modelinfo
31from .details import make_kernel_args, dispersion_mesh
32
33try:
34    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable
35    from .modelinfo import ModelInfo, Parameter
36    from .kernel import KernelModel
37    MultiplicityInfoType = NamedTuple(
38        'MultiplicityInfo',
39        [("number", int), ("control", str), ("choices", List[str]),
40         ("x_axis_label", str)])
41    SasviewModelType = Callable[[int], "SasviewModel"]
42except ImportError:
43    pass
44
45logger = logging.getLogger(__name__)
46
47calculation_lock = thread.allocate_lock()
48
49#: True if pre-existing plugins, with the old names and parameters, should
50#: continue to be supported.
51SUPPORT_OLD_STYLE_PLUGINS = True
52
53# TODO: separate x_axis_label from multiplicity info
54MultiplicityInfo = collections.namedtuple(
55    'MultiplicityInfo',
56    ["number", "control", "choices", "x_axis_label"],
57)
58
59#: set of defined models (standard and custom)
60MODELS = {}  # type: Dict[str, SasviewModelType]
61#: custom model {path: model} mapping so we can check timestamps
62MODEL_BY_PATH = {}  # type: Dict[str, SasviewModelType]
63
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Look up a model class by name.

    A name ending in ``.py`` is treated as the path to a custom model file and
    handed to load_custom_model.  Any other name must already be registered
    in MODELS.

    :raises ValueError: if *modelname* is not a registered model.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
78
79
80# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Load and return the list of predefined models.

    Every name from core.list_models() is built and registered in MODELS.
    If there is an error loading a model, then the traceback is logged
    (together with the model name) and the model is not registered.  When
    SUPPORT_OLD_STYLE_PLUGINS is set, the models are also registered under
    their old sasview names.

    The return value lists everything currently in MODELS, which includes
    any custom models registered earlier by load_custom_model.
    """
    for name in core.list_models():
        try:
            MODELS[name] = _make_standard_model(name)
        except Exception:
            # logger.exception logs at ERROR level and appends the traceback;
            # naming the model makes the failing model file easy to find.
            logger.exception("Failed to load standard model %r", name)
    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()

    return list(MODELS.values())
98
99
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load a custom model given the model path.

    Models are cached in MODEL_BY_PATH.  The cached class is returned as long
    as the file modification time has not changed.  Otherwise the file is
    (re)loaded, either as an old-style plugin exposing a ``Model`` class or
    as a sasmodels kernel module, and registered in MODELS.

    If a different file is already registered under the same model name, the
    new model is renamed to its file id (or ``<id>_user`` if that name is
    also taken by another file).
    """
    # Read the timestamp once, before loading.  If the file changes while it
    # is being loaded, the older timestamp forces a reload on the next call
    # rather than caching the newer timestamp against stale code.
    timestamp = getmtime(path)
    model = MODEL_BY_PATH.get(path, None)
    if model is not None and model.timestamp == timestamp:
        #logger.info("Model already loaded %s", path)
        return model

    #logger.info("Loading model %s", path)
    kernel_module = custom.load_custom_kernel_module(path)
    if hasattr(kernel_module, 'Model'):
        model = kernel_module.Model
        # Old style models do not set the name in the class attributes, so
        # set it here; this name will be overridden when the object is created
        # with an instance variable that has the same value.
        if model.name == "":
            model.name = splitext(basename(path))[0]
        if not hasattr(model, 'filename'):
            filename = abspath(kernel_module.__file__)
            # Point at the source rather than the compiled file.  Only a
            # trailing '.pyc' is touched, so directory names are left intact.
            if filename.endswith('.pyc'):
                filename = filename[:-1]
            model.filename = filename
        if not hasattr(model, 'id'):
            model.id = splitext(basename(model.filename))[0]
    else:
        model_info = modelinfo.make_model_info(kernel_module)
        model = make_model_from_info(model_info)
    model.timestamp = timestamp

    # If a model name already exists and we are loading a different model,
    # use the model file name as the model name.
    if model.name in MODELS and model.filename != MODELS[model.name].filename:
        _previous_name = model.name
        model.name = model.id

        # If the new model name is still in the model list (for instance,
        # if we put a cylinder.py in our plug-in directory), then append
        # an identifier.
        if model.name in MODELS and model.filename != MODELS[model.name].filename:
            model.name = model.id + '_user'
        logger.info("Model %s already exists: using %s [%s]",
                    _previous_name, model.name, model.filename)

    MODELS[model.name] = model
    MODEL_BY_PATH[path] = model
    return model
145
146
def make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Build a SasviewModel subclass describing *model_info*.

    The new class is named after the model.  Its attributes come from
    _generate_model_attributes, plus the model's source file name.
    """
    class_attrs = _generate_model_attributes(model_info)

    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs['__init__'] = __init__
    class_attrs['filename'] = model_info.filename
    model_class = type(model_info.name, (SasviewModel,), class_attrs)  # type: SasviewModelType
    return model_class
159
160
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build the sasview model class for the model called *name*.

    *name* may be a standard model name or a path to a custom model; it is
    resolved to a kernel module by generate.load_kernel_module.

    Returns a class that can be used directly as a sasview model.
    """
    return make_model_from_info(
        modelinfo.make_model_info(generate.load_kernel_module(name)))
173
174
def _register_old_models():
    # type: () -> None
    """
    Place the new models into sasview under the old names.

    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
    is available to the plugin modules.

    Each entry of the conversion table for version (3, 1, 2) is exposed as a
    pseudo-module ``sas.models.<OldName>``.  That pseudo-module has a single
    attribute ``<OldName>`` bound to the new model class, so that
    ``from sas.models.CylinderModel import CylinderModel`` keeps working.
    """
    import sys
    import sas   # needed in order to set sas.models
    import sas.sascalc.fit
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit
    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE

    for new_name, conversion in CONVERSION_TABLE.get((3, 1, 2), {}).items():
        # CoreShellEllipsoidModel => core_shell_ellipsoid:1
        new_name = new_name.split(':')[0]
        old_name = conversion[0] if len(conversion) < 3 else conversion[2]
        module_attrs = {old_name: find_model(new_name)}
        ConstructedModule = type(old_name, (), module_attrs)
        old_path = 'sas.models.' + old_name
        # Bind the pseudo-module on the package under its short name, so that
        # the attribute sas.models.<OldName> resolves.  Also register it in
        # sys.modules so that the import statement finds it.
        setattr(sas.models, old_name, ConstructedModule)
        sys.modules[old_path] = ConstructedModule
190
191    for new_name, conversion in CONVERSION_TABLE.get((3, 1, 2), {}).items():
192        # CoreShellEllipsoidModel => core_shell_ellipsoid:1
193        new_name = new_name.split(':')[0]
194        old_name = conversion[0] if len(conversion) < 3 else conversion[2]
195        module_attrs = {old_name: find_model(new_name)}
196        ConstructedModule = type(old_name, (), module_attrs)
197        old_path = 'sas.models.' + old_name
198        setattr(sas.models, old_path, ConstructedModule)
199        sys.modules[old_path] = ConstructedModule
200
201
def MultiplicationModel(form_factor, structure_factor):
    # type: ("SasviewModel", "SasviewModel") -> "SasviewModel"
    """
    Return an instance of the product model P(q)*S(q).

    The product model info is built from the ``_model_info`` of
    *form_factor* and *structure_factor*.  The resulting model is
    instantiated with the form factor's multiplicity.
    """
    product_info = product.make_product_info(
        form_factor._model_info, structure_factor._model_info)
    product_class = make_model_from_info(product_info)
    return product_class(form_factor.multiplicity)
211
212
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Build the class-attribute dictionary for a SasviewModel subclass.

    Everything needed to describe the model is computed here, so the model
    class can be queried without creating an instance.  This covers the
    name, category, parameter groupings and multiplicity information.
    Sequence-valued attributes are stored as tuples so they are immutable.
    """
    # TODO: allow model to override axis labels input/output name/unit
    kernel_pars = model_info.parameters.kernel_parameters

    # Multiplicity: the control parameter (if any) selects the variant and
    # is never fitted.
    if model_info.profile is not None:
        xlabel = model_info.profile_axes[0]
    else:
        xlabel = ""
    non_fittable = []  # type: List[str]
    variants = MultiplicityInfo(0, "", [], xlabel)
    control_par = next(
        (p for p in kernel_pars if p.name == model_info.control), None)
    if control_par is not None:
        non_fittable.append(control_par.name)
        if control_par.choices:
            count = len(control_par.choices)
        else:
            count = int(control_par.limits[1])
        variants = MultiplicityInfo(
            count, control_par.name, control_par.choices, xlabel)

    # Only a single drop-down list parameter is available: use the first
    # parameter that has choices.
    fun_list = []
    choice_par = next((p for p in kernel_pars if p.choices), None)
    if choice_par is not None:
        fun_list = choice_par.choices
        if choice_par.length > 1:
            non_fittable.extend(choice_par.id + str(index)
                                for index in range(1, choice_par.length + 1))

    # Organize parameter sets
    orientation_params = []
    magnetic_params = []
    fixed = []
    for par in model_info.parameters.user_parameters({}, is2d=True):
        width_name = par.name + ".width"
        if par.type == 'orientation':
            orientation_params.extend([par.name, width_name])
            fixed.append(width_name)
        elif par.type == 'magnetic':
            orientation_params.append(par.name)
            magnetic_params.append(par.name)
            fixed.append(width_name)

    # Build class dictionary
    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }  # type: Dict[str, Any]
281
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete model classes are created by make_model_from_info.  That
    function subclasses SasviewModel and fills in the per-model class
    attributes declared below.
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: List[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: List[str]
    #: names of the fittable parameters
    fixed = None              # type: List[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]
    #: the calculator's ``results`` attribute from the most recent
    #: calculation, or None if the calculator does not provide one
    _intermediate_results = None

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Create a model instance with default parameter values.

        :param multiplicity: variant selector for multiplicity models, or None.
            When given, the parameters hidden for that multiplicity, and the
            control parameter itself, are left out of *params*.
        """
        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            hidden = self._model_info.get_hidden_parameters(multiplicity)
            hidden |= set([self.multiplicity_info.control])
        else:
            hidden = set()
        if self._model_info.structure_factor:
            # Structure factors are used inside products, so scale and
            # background are hidden; the shared default background is forced
            # to zero for _get_weights.
            hidden.add('scale')
            hidden.add('background')
            self._model_info.parameters.defaults['background'] = 0.

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters({}, is2d=True):
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __getstate__(self):
        # type: () -> Dict[str, Any]
        """
        Return the instance state for pickle/copy, without the kernel model.

        The kernel model is rebuilt on demand by _calculate_Iq, so it is
        dropped from the state.  It is only present in the instance dict
        after a calculation has run, so its absence is tolerated.
        """
        state = self.__dict__.copy()
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __setstate__(self, state):
        # type: (Dict[str, Any]) -> None
        """Restore state from __getstate__; the kernel model is rebuilt lazily."""
        self.__dict__ = state
        self._model = None

    # Backward-compatible aliases for the previous (misspelled) hook names,
    # which pickle and copy never called.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check
        """
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()


    def getProfile(self):
        # type: () -> Tuple[np.ndarray, np.ndarray]
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        control = self.multiplicity_info.control
        args = {} # type: Dict[str, Any]
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == control and self.multiplicity is not None:
                # The control parameter is hidden from params when a
                # multiplicity is given, so use the multiplicity itself.
                value = float(self.multiplicity)
            elif p.length == 1:
                value = self.params.get(p.id, np.nan)
            else:
                value = np.array([self.params.get(p.id+str(k), np.nan)
                                  for k in range(1, p.length+1)])
            args[p.id] = value

        x, y = self._model_info.profile(**args)
        # Scale SLD values by 1e-6 before returning them.
        return x, 1e-6*y

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        :param name: name of the parameter, or ``<par>.<field>`` for a
            dispersion field such as ``radius.width``
        :param value: value of the parameter

        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            par_name, field = toks
            if par_name in self.dispersion and field in self.dispersion[par_name]:
                self.dispersion[par_name][field] = value
                return
        elif name in self.params:
            # Look for standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        :param name: name of the parameter, or ``<par>.<field>`` for a
            dispersion field such as ``radius.width``

        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            par_name, field = toks
            if par_name in self.dispersion and field in self.dispersion[par_name]:
                return self.dispersion[par_name][field]
        elif name in self.params:
            # Look for standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model
        """
        # TODO: fix test so that parameter order doesn't matter
        ret = ['%s.%s' % (p_name, ext)
               for p_name in self.dispersion.keys()
               for ext in ('npts', 'nsigmas', 'width')]
        return ret

    def clone(self):
        # type: () -> "SasviewModel"
        """
        Return an identical copy of self.

        The copy does not share the compiled kernel model; it is rebuilt on
        the first calculation.
        """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]


    def runXY(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        :raises TypeError: if *qdist* is neither an array nor a [qx, qy] pair
        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info.parameters.has_2d:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calc_composition_models(self, qx):
        """
        returns parts of the composition model or None if not a composition
        model.
        """
        # TODO: have calculate_Iq return the intermediates.
        #
        # The current interface causes calculate_Iq() to be called twice,
        # once to get the combined result and again to get the intermediate
        # results.  This is necessary for now.
        # Long term, the solution is to change the interface to calculate_Iq
        # so that it returns a results object containing all the bits:
        #     the A, B, C, ... of the composition model (and any subcomponents?)
        #     the P and S of the product model,
        #     the combined model before resolution smearing,
        #     the sasmodel before sesans conversion,
        #     the oriented 2D model used to fit oriented usans data,
        #     the final I(q),
        #     ...
        #
        # Have the model calculator add all of these blindly to the data
        # tree, and update the graphs which contain them.  The fitter
        # needs to be updated to use the I(q) value only, ignoring the rest.
        #
        # The simple fix of returning the existing intermediate results
        # will not work for a couple of reasons: (1) another thread may
        # sneak in to compute its own results before calc_composition_models
        # is called, and (2) calculate_Iq is currently called three times:
        # once with q, once with q values before qmin and once with q values
        # after q max.  Both of these should be addressed before
        # replacing this code.
        composition = self._model_info.composition
        if composition and composition[0] == 'product': # only P*S for now
            with calculation_lock:
                self._calculate_Iq(qx)
                return self._intermediate_results
        else:
            return None

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        ## uncomment the following when trying to debug the uncoordinated calls
        ## to calculate_Iq
        #if calculation_lock.locked():
        #    logger.info("calculation waiting for another thread to complete")
        #    logger.info("\n".join(traceback.format_stack()))

        with calculation_lock:
            return self._calculate_Iq(qx, qy)

    def _calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Evaluate the kernel at the given q points.

        Both callers in this class hold *calculation_lock* while calling it.
        The kernel and the kernel model are released even if the evaluation
        raises.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        try:
            calculator = self._model.make_kernel(q_vectors)
            try:
                parameters = self._model_info.parameters
                pairs = [self._get_weights(p) for p in parameters.call_parameters]
                call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
                result = calculator(call_details, values, cutoff=self.cutoff,
                                    magnetic=is_magnetic)
                self._intermediate_results = getattr(calculator, 'results', None)
            finally:
                calculator.release()
        finally:
            self._model.release()
        return result

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            # Weighted average over the dispersion mesh.
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion

        :raises ValueError: if *parameter* is not a model parameter
        """
        if parameter in self.params:
            # TODO: Store the disperser object directly in the model.
            # The current method of relying on the sasview GUI to
            # remember them is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.
            # (2) Could write a disperser parameter retriever in sasview.
            # (3) Could modify sasview to use sasmodels.weights dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

    def _dispersion_mesh(self):
        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[Any, Sequence[float], Sequence[float]]
        """
        Return (value, dispersity points, weights) for parameter *par*.

        Parameters hidden from *params* use the multiplicity (for the control
        parameter) or the model default, with a single unit weight.
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return self.multiplicity, [self.multiplicity], [1.0]
            else:
                # For hidden parameters use default values.  This sets
                # scale=1 and background=0 for structure factors
                default = self._model_info.parameters.defaults.get(par.name, np.nan)
                return default, [default], [1.0]
        elif par.polydisperse:
            value = self.params[par.name]
            dis = self.dispersion[par.name]
            if dis['type'] == 'array':
                dispersity, weight = dis['values'], dis['weights']
            else:
                dispersity, weight = weights.get_weights(
                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                    value, par.limits, par.relative_pd)
            return value, dispersity, weight
        else:
            value = self.params[par.name]
            return value, [value], [1.0]
786
def test_cylinder():
    # type: () -> np.ndarray
    """
    Check that the cylinder model evaluates in 2-D at (qx, qy) = (0.1, 0.1),
    returning the computed intensity.
    """
    model = _make_standard_model('cylinder')()
    return model.evalDistribution([0.1, 0.1])
795
def test_structure_factor():
    # type: () -> None
    """
    Check that the hardsphere structure factor gives a finite (non-NaN) value
    in 2-D at (0.1, 0.1) and in 1-D at the matching |q|.
    """
    hardsphere = _make_standard_model('hardsphere')()
    q_2d = hardsphere.evalDistribution([0.1, 0.1])
    q_1d = hardsphere.evalDistribution(np.array([0.1*np.sqrt(2)]))
    if np.isnan(q_1d) or np.isnan(q_2d):
        raise ValueError("hardsphere returns nan")
808
def test_product():
    # type: () -> None
    """
    Test that the 2-D cylinder*hayter_msa product model runs and doesn't
    produce NaN.
    """
    S = _make_standard_model('hayter_msa')()
    P = _make_standard_model('cylinder')()
    model = MultiplicationModel(P, S)
    value = model.evalDistribution([0.1, 0.1])
    if np.isnan(value):
        raise ValueError("cylinder*hayter_msa returns NaN")
820
def test_rpa():
    # type: () -> np.ndarray
    """
    Check that the RPA model with multiplicity 3 evaluates in 2-D at
    (0.1, 0.1), returning the computed intensity.
    """
    rpa_model = _make_standard_model('rpa')(3)
    return rpa_model.evalDistribution([0.1, 0.1])
829
def test_empty_distribution():
    # type: () -> None
    """
    Check that a parameter set with no valid polydispersity points (a negative
    cylinder radius) yields NaN rather than a number.
    """
    model = _make_standard_model('cylinder')()
    for par_name, par_value in (('radius', -1.0), ('background', 0.)):
        model.setParam(par_name, par_value)
    intensity = model.evalDistribution(np.asarray([0.1]))
    assert np.isnan(intensity[0]), "empty distribution fails"
841
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models

    Any failure is annotated with the name of the model being loaded and
    then re-raised.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            annotate_exception("when loading "+name)
            raise
854
def test_old_name():
    # type: () -> None
    """
    Check that the cylinder model can be imported and run through its old
    sasview name, sas.models.CylinderModel.

    Skipped when old-style plugin support is off or sasview is not importable.
    """
    if SUPPORT_OLD_STYLE_PLUGINS:
        try:
            # if sasview is not on the path then don't try to test it
            import sas
        except ImportError:
            return
        load_standard_models()
        from sas.models.CylinderModel import CylinderModel
        CylinderModel().evalDistribution([0.1, 0.1])
870
if __name__ == "__main__":
    # test_cylinder returns a one-element array; format its scalar value, since
    # converting an ndim>0 array to a float is deprecated in NumPy.
    print("cylinder(0.1,0.1)=%g"%test_cylinder()[0])
    #test_product()
    #test_structure_factor()
    #print("rpa:", test_rpa())
    #test_empty_distribution()
Note: See TracBrowser for help on using the repository browser.