source: sasmodels/sasmodels/sasview_model.py @ bd36af0

core_shell_microgels, costrafo411, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since bd36af0 was edb0f85, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

tweak implementation of sld profile for product models

  • Property mode set to 100644
File size: 29.8 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17from os.path import basename, splitext, abspath, getmtime
18try:
19    import _thread as thread
20except ImportError:
21    import thread
22
23import numpy as np  # type: ignore
24
25from . import core
26from . import custom
27from . import product
28from . import generate
29from . import weights
30from . import modelinfo
31from .details import make_kernel_args, dispersion_mesh
32
33try:
34    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable
35    from .modelinfo import ModelInfo, Parameter
36    from .kernel import KernelModel
37    MultiplicityInfoType = NamedTuple(
38        'MuliplicityInfo',
39        [("number", int), ("control", str), ("choices", List[str]),
40         ("x_axis_label", str)])
41    SasviewModelType = Callable[[int], "SasviewModel"]
42except ImportError:
43    pass
44
logger = logging.getLogger(__name__)

#: Module-wide lock; calculate_Iq and calc_composition_models hold it while
#: a kernel evaluation runs.
calculation_lock = thread.allocate_lock()

#: When True, pre-existing plugins that use the old model names and
#: parameters are still supported.
SUPPORT_OLD_STYLE_PLUGINS = True

# TODO: separate x_axis_label from multiplicity info
#: Multiplicity description: number of variants, name of the control
#: parameter, the choice labels, and the x axis label of the SLD profile.
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo', "number control choices x_axis_label")

#: All defined models (standard and custom), keyed by model name.
MODELS = {}  # type: Dict[str, SasviewModelType]
#: Custom models keyed by file path, so their timestamps can be checked.
MODEL_BY_PATH = {}  # type: Dict[str, SasviewModelType]
63
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Look up a model class by name.

    A name ending in ``.py`` is treated as the path of a custom model file
    and handed to load_custom_model; any other name must already be a key
    of MODELS.

    Raises ValueError when the name is not a known model.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
78
79
# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Build every predefined model, register it in MODELS, and return a list
    of all classes currently in MODELS.

    A model that fails to build is skipped and its traceback is logged at
    error level.  When SUPPORT_OLD_STYLE_PLUGINS is set, the models are
    also registered under their old sasview names.
    """
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
        except Exception:
            logger.error(traceback.format_exc())
        else:
            MODELS[model_name] = model_class
    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()
    return list(MODELS.values())
98
99
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load a custom model given the model path.

    The loaded class is cached in MODEL_BY_PATH along with the file's
    modification time.  A later call with the same *path* returns the
    cached class unless the file's modification time has changed.

    The class is also registered in MODELS.  If a model from a different
    file is already registered under the same name, the new model is
    renamed to its ``id``.  If that name is also taken by a different file,
    it is renamed to ``<id>_user``.
    """
    # Read the timestamp once, *before* loading.  An edit made while the
    # module is being loaded then leaves the cache entry stale, instead of
    # marking the old contents as current.
    timestamp = getmtime(path)
    model = MODEL_BY_PATH.get(path, None)
    if model is not None and model.timestamp == timestamp:
        #logger.info("Model already loaded %s", path)
        return model

    #logger.info("Loading model %s", path)
    kernel_module = custom.load_custom_kernel_module(path)
    if hasattr(kernel_module, 'Model'):
        # Old-style plugin: the module defines the model class directly.
        model = kernel_module.Model
        # Old style models do not set the name in the class attributes, so
        # set it here; this name will be overridden when the object is created
        # with an instance variable that has the same value.
        if model.name == "":
            model.name = splitext(basename(path))[0]
        if not hasattr(model, 'filename'):
            model.filename = abspath(kernel_module.__file__).replace('.pyc', '.py')
        if not hasattr(model, 'id'):
            model.id = splitext(basename(model.filename))[0]
    else:
        model_info = modelinfo.make_model_info(kernel_module)
        model = make_model_from_info(model_info)
    model.timestamp = timestamp

    # If a model name already exists and we are loading a different model,
    # use the model id as the model name.
    if model.name in MODELS and model.filename != MODELS[model.name].filename:
        previous_name = model.name
        model.name = model.id

        # If the new model name is still in the model list (for instance,
        # if we put a cylinder.py in our plug-in directory), then append
        # an identifier.
        if model.name in MODELS and model.filename != MODELS[model.name].filename:
            model.name = model.id + '_user'
        logger.info("Model %s already exists: using %s [%s]",
                    previous_name, model.name, model.filename)

    MODELS[model.name] = model
    MODEL_BY_PATH[path] = model
    return model
145
146
def make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Create a SasviewModel subclass that wraps *model_info*.

    The class is named after the model.  Its class attributes come from
    _generate_model_attributes, plus the model's ``filename`` and an
    ``__init__`` that forwards *multiplicity* to SasviewModel.__init__.
    """
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs = _generate_model_attributes(model_info)
    class_attrs.update(__init__=__init__, filename=model_info.filename)
    ConstructedModel = type(model_info.name, (SasviewModel,), class_attrs) # type: SasviewModelType
    return ConstructedModel
159
160
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build the sasview model class for the model called *name*.

    The kernel module is located with generate.load_kernel_module, turned
    into a ModelInfo, and wrapped with make_model_from_info.  The returned
    class can be used directly as a sasview model.
    """
    kernel_module = generate.load_kernel_module(name)
    return make_model_from_info(modelinfo.make_model_info(kernel_module))
173
174
def _register_old_models():
    # type: () -> None
    """
    Place the new models into sasview under the old names.

    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
    is available to the plugin modules.  Then, for each entry in the 3.1.2
    section of the conversion table, create a stand-in module object that
    holds the new model class under its old name.  The stand-in is stored
    both in sys.modules (so ``from sas.models.<Old> import <Old>`` works)
    and as attribute ``<Old>`` of sas.models.

    An entry whose new model cannot be found (for example because it failed
    to build in load_standard_models) is logged and skipped.  It does not
    abort the remaining registrations.
    """
    import sys
    import sas   # needed in order to set sas.models
    import sas.sascalc.fit
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit

    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE
    for new_name, conversion in CONVERSION_TABLE.get((3, 1, 2), {}).items():
        # CoreShellEllipsoidModel => core_shell_ellipsoid:1
        new_name = new_name.split(':')[0]
        old_name = conversion[0] if len(conversion) < 3 else conversion[2]
        try:
            model = find_model(new_name)
        except ValueError:
            logger.warning("Cannot register old model name %s: model %s not found",
                           old_name, new_name)
            continue
        ConstructedModule = type(old_name, (), {old_name: model})
        old_path = 'sas.models.' + old_name
        # Bind under the bare old name (not the dotted path) so that
        # attribute access sas.models.<Old> resolves.
        setattr(sas.models, old_name, ConstructedModule)
        sys.modules[old_path] = ConstructedModule
200
201
def MultiplicationModel(form_factor, structure_factor):
    # type: ("SasviewModel", "SasviewModel") -> "SasviewModel"
    """
    Return an instance of the product model of *form_factor* and
    *structure_factor*.

    The product ModelInfo is built from the two models' ModelInfo objects
    and wrapped as a new SasviewModel class.  That class is instantiated
    with the multiplicity of *form_factor*.
    """
    product_info = product.make_product_info(
        form_factor._model_info, structure_factor._model_info)
    product_class = make_model_from_info(product_info)
    return product_class(form_factor.multiplicity)
208
209
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Build the class-attribute dictionary for a SasviewModel subclass.

    Everything needed to query the model's details is computed here, so the
    model can be inspected without creating an instance.  The values are
    meant to be immutable (tuples, namedtuples, scalars) to avoid accidents.
    """

    # TODO: allow model to override axis labels input/output name/unit

    kernel_pars = model_info.parameters.kernel_parameters
    xlabel = "" if model_info.profile is None else model_info.profile_axes[0]

    # Multiplicity: the first kernel parameter named by model_info.control
    # describes the variants and is never fitted.
    non_fittable = []  # type: List[str]
    variants = MultiplicityInfo(0, "", [], xlabel)
    control_par = next(
        (p for p in kernel_pars if p.name == model_info.control), None)
    if control_par is not None:
        non_fittable.append(control_par.name)
        if control_par.choices:
            n_variants = len(control_par.choices)
        else:
            n_variants = int(control_par.limits[1])
        variants = MultiplicityInfo(
            n_variants, control_par.name, control_par.choices, xlabel)

    # Only a single drop-down list parameter is available: the first kernel
    # parameter with choices.  Its vector elements are not fittable.
    fun_list = []
    choice_par = next((p for p in kernel_pars if p.choices), None)
    if choice_par is not None:
        fun_list = choice_par.choices
        if choice_par.length > 1:
            non_fittable.extend(
                choice_par.id + str(k) for k in range(1, choice_par.length + 1))

    # Orientation and magnetic parameter sets; the ".width" names of both
    # kinds are collected in `fixed`.
    orientation_params = []
    magnetic_params = []
    fixed = []
    for p in model_info.parameters.user_parameters({}, is2d=True):
        width_name = p.name + ".width"
        if p.type == 'orientation':
            orientation_params.extend((p.name, width_name))
            fixed.append(width_name)
        elif p.type == 'magnetic':
            # Magnetic parameters are listed with the orientation ones too.
            orientation_params.append(p.name)
            magnetic_params.append(p.name)
            fixed.append(width_name)

    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }
278
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete subclasses are created by make_model_from_info, which fills in
    the class attributes below from a ModelInfo.  An instance holds the
    current parameter values (*params*), the polydispersity settings
    (*dispersion*) and the parameter units/limits (*details*).  The kernel
    model is built lazily on the first evaluation.
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: List[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: List[str]
    #: names of the fittable parameters
    fixed = None              # type: List[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]
    #: kernel intermediate results from the last evaluation (None if absent)
    _intermediate_results = None  # type: Any

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Initialize the parameter values, dispersion settings and details.

        When *multiplicity* is given, the parameters hidden for that
        multiplicity and the multiplicity control parameter itself are left
        out of *params*.  For structure factors, scale and background are
        hidden, and the model info's default background is set to 0.
        """

        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            hidden = self._model_info.get_hidden_parameters(multiplicity)
            hidden |= set([self.multiplicity_info.control])
        else:
            hidden = set()
        if self._model_info.structure_factor:
            hidden.add('scale')
            hidden.add('background')
            self._model_info.parameters.defaults['background'] = 0.

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters({}, is2d=True):
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __getstate__(self):
        # type: () -> Dict[str, Any]
        """
        Return the instance state for pickle/copy without the kernel model.

        The kernel model is rebuilt by _calculate_Iq on the next evaluation.
        """
        state = self.__dict__.copy()
        # _model only appears in the instance dict after a calculation.
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __setstate__(self, state):
        # type: (Dict[str, Any]) -> None
        """Restore *state*; the kernel model is rebuilt lazily."""
        self.__dict__ = state
        self._model = None

    # Backward-compatible aliases for the former spellings, which the pickle
    # and copy protocols never invoked.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check
        """
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()


    def getProfile(self):
        # type: () -> (np.ndarray, np.ndarray)
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values

        The SLD values returned by the model's profile function are
        multiplied by 1e-6.  Parameters missing from *params* are passed
        as NaN.
        """
        control = self.multiplicity_info.control
        args = {} # type: Dict[str, Any]
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == control and self.multiplicity is not None:
                # The control is hidden from params; use the instance value.
                value = float(self.multiplicity)
            elif p.length == 1:
                value = self.params.get(p.id, np.nan)
            else:
                value = np.array([self.params.get(p.id+str(k), np.nan)
                                  for k in range(1, p.length+1)])
            args[p.id] = value

        x, y = self._model_info.profile(**args)
        return x, 1e-6*y

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        :param name: name of the parameter; ``par.field`` (for example
            ``radius.width``) addresses a dispersion setting
        :param value: value of the parameter

        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            par, field = toks
            if field in self.dispersion.get(par, ()):
                self.dispersion[par][field] = value
                return
        elif name in self.params:
            # Look for standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        :param name: name of the parameter; ``par.field`` (for example
            ``radius.width``) addresses a dispersion setting

        :raises ValueError: if the model has no such parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            par, field = toks
            if field in self.dispersion.get(par, ()):
                return self.dispersion[par][field]
        elif name in self.params:
            # Look for standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model
        """
        # TODO: fix test so that parameter order doesn't matter
        return ['%s.%s' % (p_name, ext)
                for p_name in self.dispersion.keys()
                for ext in ('npts', 'nsigmas', 'width')]

    def clone(self):
        # type: () -> "SasviewModel"
        """ Return a identical copy of self (without the built kernel) """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]


    def runXY(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info.parameters.has_2d:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calc_composition_models(self, qx):
        """
        returns parts of the composition model or None if not a composition
        model.
        """
        # TODO: have calculate_Iq return the intermediates.
        #
        # The current interface causes calculate_Iq() to be called twice,
        # once to get the combined result and again to get the intermediate
        # results.  This is necessary for now.
        # Long term, the solution is to change the interface to calculate_Iq
        # so that it returns a results object containing all the bits:
        #     the A, B, C, ... of the composition model (and any subcomponents?)
        #     the P and S of the product model,
        #     the combined model before resolution smearing,
        #     the sasmodel before sesans conversion,
        #     the oriented 2D model used to fit oriented usans data,
        #     the final I(q),
        #     ...
        #
        # Have the model calculator add all of these blindly to the data
        # tree, and update the graphs which contain them.  The fitter
        # needs to be updated to use the I(q) value only, ignoring the rest.
        #
        # The simple fix of returning the existing intermediate results
        # will not work for a couple of reasons: (1) another thread may
        # sneak in to compute its own results before calc_composition_models
        # is called, and (2) calculate_Iq is currently called three times:
        # once with q, once with q values before qmin and once with q values
        # after q max.  Both of these should be addressed before
        # replacing this code.
        composition = self._model_info.composition
        if composition and composition[0] == 'product': # only P*S for now
            with calculation_lock:
                self._calculate_Iq(qx)
                return self._intermediate_results
        else:
            return None

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        ## uncomment the following when trying to debug the uncoordinated calls
        ## to calculate_Iq
        #if calculation_lock.locked():
        #    logger.info("calculation waiting for another thread to complete")
        #    logger.info("\n".join(traceback.format_stack()))

        with calculation_lock:
            return self._calculate_Iq(qx, qy)

    def _calculate_Iq(self, qx, qy=None):
        """
        Evaluate the kernel for the current parameters.

        Callers (calculate_Iq, calc_composition_models) hold calculation_lock.
        The kernel model is built on first use.  The kernel and the model
        are released even if the evaluation raises.
        """
        #core.HAVE_OPENCL = False
        if self._model is None:
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        try:
            calculator = self._model.make_kernel(q_vectors)
            try:
                parameters = self._model_info.parameters
                pairs = [self._get_weights(p) for p in parameters.call_parameters]
                #weights.plot_weights(self._model_info, pairs)
                call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
                result = calculator(call_details, values, cutoff=self.cutoff,
                                    magnetic=is_magnetic)
                self._intermediate_results = getattr(calculator, 'results', None)
            finally:
                calculator.release()
        finally:
            self._model.release()
        return result

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion

        :raises ValueError: if *parameter* is not a model parameter
        """
        if parameter not in self.params:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)
        # TODO: Store the disperser object directly in the model.
        # The current method of relying on the sasview GUI to
        # remember them is kind of funky.
        # Note: can't seem to get disperser parameters from sasview
        # (1) Could create a sasview model that has not yet been
        # converted, assign the disperser to one of its polydisperse
        # parameters, then retrieve the disperser parameters from the
        # sasview model.
        # (2) Could write a disperser parameter retriever in sasview.
        # (3) Could modify sasview to use sasmodels.weights dispersers.
        # For now, rely on the fact that the sasview only ever uses
        # new dispersers in the set_dispersion call and create a new
        # one instead of trying to assign parameters.
        self.dispersion[parameter] = dispersion.get_pars()

    def _dispersion_mesh(self):
        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
        """
        Return dispersion weights for parameter

        Parameters hidden from *params* get a single point: the multiplicity
        for the multiplicity control, otherwise the model default (NaN if
        there is none).  Polydisperse parameters get normalized weights.
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return [self.multiplicity], [1.0]
            else:
                # For hidden parameters use the default value.
                value = self._model_info.parameters.defaults.get(par.name, np.nan)
                return [value], [1.0]
        elif par.polydisperse:
            dis = self.dispersion[par.name]
            if dis['type'] == 'array':
                value, weight = dis['values'], dis['weights']
            else:
                value, weight = weights.get_weights(
                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                    self.params[par.name], par.limits, par.relative_pd)
            return value, weight / np.sum(weight)
        else:
            return [self.params[par.name]], [1.0]
777
def test_cylinder():
    # type: () -> float
    """
    Check that the cylinder model builds and evaluates, and return its
    value at qx = qy = 0.1.
    """
    model = _make_standard_model('cylinder')()
    return model.evalDistribution([0.1, 0.1])
786
def test_structure_factor():
    # type: () -> None
    """
    Test that 2-D hardsphere model runs and doesn't produce NaN.

    :raises ValueError: if any evaluated value at qx = qy = 0.1 is NaN
    """
    Model = _make_standard_model('hardsphere')
    model = Model()
    value = model.evalDistribution([0.1, 0.1])
    # np.any keeps the check unambiguous even if more than one value is returned.
    if np.any(np.isnan(value)):
        raise ValueError("hardsphere returns null")
797
def test_rpa():
    # type: () -> float
    """
    Check that the 2-D RPA model runs with multiplicity 3, and return its
    value at qx = qy = 0.1.
    """
    rpa = _make_standard_model('rpa')(3)
    return rpa.evalDistribution([0.1, 0.1])
806
def test_empty_distribution():
    # type: () -> None
    """
    Check that sasmodels returns NaN when there are no polydispersity
    points.  The empty distribution is provoked here with a negative
    cylinder radius.
    """
    cylinder = _make_standard_model('cylinder')()
    for par_name, par_value in (('radius', -1.0), ('background', 0.)):
        cylinder.setParam(par_name, par_value)
    Iq = cylinder.evalDistribution(np.asarray([0.1]))
    assert np.isnan(Iq[0]), "empty distribution fails"
818
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models.

    On failure, annotate_exception is called with the model name and the
    exception is re-raised.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            annotate_exception("when loading "+name)
            raise
831
def test_old_name():
    # type: () -> None
    """
    Load and run cylinder model from sas.models.CylinderModel

    Returns early when old-style plugin support is disabled or when the
    ``sas`` package (sasview) cannot be imported.
    """
    if not SUPPORT_OLD_STYLE_PLUGINS:
        return
    # Only meaningful when sasview is on the path.
    try:
        import sas  # pylint: disable=unused-import
    except ImportError:
        return
    load_standard_models()
    from sas.models.CylinderModel import CylinderModel as OldCylinderModel
    old_model = OldCylinderModel()
    old_model.evalDistribution([0.1, 0.1])
847
if __name__ == "__main__":
    #test_empty_distribution()
    cylinder_value = test_cylinder()
    print("cylinder(0.1,0.1)=%g" % cylinder_value)
Note: See TracBrowser for help on using the repository browser.