source: sasmodels/sasmodels/sasview_model.py @ 910c0f4

Last change on this file since 910c0f4 was aa25fc7, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

load user-defined weight functions from ~/.sasview/weights/*.py

  • Property mode set to 100644
File size: 31.4 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17from os.path import basename, splitext, abspath, getmtime
18try:
19    import _thread as thread
20except ImportError:
21    import thread
22
23import numpy as np  # type: ignore
24
25from . import core
26from . import custom
27from . import product
28from . import generate
29from . import weights
30from . import modelinfo
31from .details import make_kernel_args, dispersion_mesh
32
33# Hack: load in any custom distributions
34# Uses ~/.sasview/weights/*.py unless SASMODELS_WEIGHTS is set in the environ.
35# Override with weights.load_weights(pattern="<weights_path>/*.py")
36weights.load_weights()
37
38# pylint: disable=unused-import
39try:
40    from typing import (Dict, Mapping, Any, Sequence, Tuple, NamedTuple,
41                        List, Optional, Union, Callable)
42    from .modelinfo import ModelInfo, Parameter
43    from .kernel import KernelModel
44    MultiplicityInfoType = NamedTuple(
45        'MultiplicityInfo',
46        [("number", int), ("control", str), ("choices", List[str]),
47         ("x_axis_label", str)])
48    SasviewModelType = Callable[[int], "SasviewModel"]
49except ImportError:
50    pass
51# pylint: enable=unused-import
52
logger = logging.getLogger(__name__)

#: Lock serializing kernel evaluations across threads; held by
#: SasviewModel.calculate_Iq and SasviewModel.calc_composition_models.
calculation_lock = thread.allocate_lock()

#: True if pre-existing plugins, with the old names and parameters, should
#: continue to be supported.
SUPPORT_OLD_STYLE_PLUGINS = True

# TODO: separate x_axis_label from multiplicity info
#: Record describing the multiplicity (variant) control of a model.
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo',
    ["number", "control", "choices", "x_axis_label"],
)

#: set of defined models (standard and custom)
MODELS = {}  # type: Dict[str, SasviewModelType]
#: custom model {path: model} mapping so we can check timestamps
MODEL_BY_PATH = {}  # type: Dict[str, SasviewModelType]
71
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Find a model by name.

    If *modelname* ends in ``.py`` it is treated as the path to a custom
    model and is loaded with :func:`load_custom_model`; otherwise it is
    looked up in the :data:`MODELS` registry.

    Raises *ValueError* if the name is not in the registry.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname in MODELS:
        return MODELS[modelname]
    raise ValueError("unknown model %r"%modelname)
86
87
88# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Load the predefined models and return the contents of the registry.

    Each name reported by *core.list_models()* is built and stored in
    :data:`MODELS`.  If there is an error loading a model, then a traceback
    is logged and that model is not registered.  When
    :data:`SUPPORT_OLD_STYLE_PLUGINS` is set, the old model names are
    registered as well.

    The returned list is everything currently in :data:`MODELS`, so it also
    contains any model registered before this call.
    """
    for name in core.list_models():
        try:
            MODELS[name] = _make_standard_model(name)
        except Exception:
            # Best effort: one broken model must not stop the others
            # from loading.
            logger.error(traceback.format_exc())
    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()

    return list(MODELS.values())
106
107
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load a custom model given the model path.

    A model previously loaded from *path* is reused as long as the file
    modification time matches the timestamp recorded when it was loaded;
    otherwise the file is (re)loaded.  The resulting class is stored in
    :data:`MODELS` under its name and in :data:`MODEL_BY_PATH` under *path*.

    If the model name is already registered for a different file, the
    model id (file name without extension) is used as the name, and if that
    is also taken by a different file, ``<id>_user`` is used.
    """
    model = MODEL_BY_PATH.get(path, None)
    if model is not None and model.timestamp == getmtime(path):
        #logger.info("Model already loaded %s", path)
        return model

    #logger.info("Loading model %s", path)
    kernel_module = custom.load_custom_kernel_module(path)
    if hasattr(kernel_module, 'Model'):
        model = kernel_module.Model
        # Old style models do not set the name in the class attributes, so
        # set it here; this name will be overridden when the object is created
        # with an instance variable that has the same value.
        if model.name == "":
            model.name = splitext(basename(path))[0]
        if not hasattr(model, 'filename'):
            model.filename = abspath(kernel_module.__file__).replace('.pyc', '.py')
        if not hasattr(model, 'id'):
            model.id = splitext(basename(model.filename))[0]
    else:
        model_info = modelinfo.make_model_info(kernel_module)
        model = make_model_from_info(model_info)
    model.timestamp = getmtime(path)

    # If a model name already exists and we are loading a different model,
    # use the model file name as the model name.
    if model.name in MODELS and model.filename != MODELS[model.name].filename:
        _previous_name = model.name
        model.name = model.id

        # If the new model name is still in the model list (for instance,
        # if we put a cylinder.py in our plug-in directory), then append
        # an identifier.
        if model.name in MODELS and model.filename != MODELS[model.name].filename:
            model.name = model.id + '_user'
        logger.info("Model %s already exists: using %s [%s]",
                    _previous_name, model.name, model.filename)

    MODELS[model.name] = model
    MODEL_BY_PATH[path] = model
    return model
153
154
def make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Convert *model_info* into a SasView model wrapper.

    Returns a new subclass of :class:`SasviewModel`, named after
    *model_info.name*, whose class attributes are produced by
    :func:`_generate_model_attributes` plus *filename*.
    """
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)
    attrs = _generate_model_attributes(model_info)
    attrs['__init__'] = __init__
    attrs['filename'] = model_info.filename
    ConstructedModel = type(model_info.name, (SasviewModel,), attrs) # type: SasviewModelType
    return ConstructedModel
167
168
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Load the sasview model defined by *name*.

    *name* can be a standard model name or a path to a custom model.

    The kernel module is loaded with *generate.load_kernel_module*, turned
    into a model info structure, and wrapped by :func:`make_model_from_info`.

    Returns a class that can be used directly as a sasview model.  The
    class is not added to the model registry by this function.
    """
    kernel_module = generate.load_kernel_module(name)
    model_info = modelinfo.make_model_info(kernel_module)
    return make_model_from_info(model_info)
181
182
def _register_old_models():
    # type: () -> None
    """
    Place the new models into sasview under the old names.

    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
    is available to the plugin modules.

    For each entry of the (3, 1, 2) conversion table a stand-in module
    object ``sas.models.<old_name>`` is registered in *sys.modules*; it has a
    single attribute *old_name* bound to the corresponding new model class.
    """
    import sys
    import sas   # needed in order to set sas.models
    import sas.sascalc.fit
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit
    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE

    for new_name, conversion in CONVERSION_TABLE.get((3, 1, 2), {}).items():
        # CoreShellEllipsoidModel => core_shell_ellipsoid:1
        new_name = new_name.split(':')[0]
        old_name = conversion[0] if len(conversion) < 3 else conversion[2]
        module_attrs = {old_name: find_model(new_name)}
        ConstructedModule = type(old_name, (), module_attrs)
        old_path = 'sas.models.' + old_name
        # NOTE(review): this sets an attribute whose name is the full dotted
        # path rather than *old_name*; imports succeed through the
        # sys.modules entry below.  Confirm whether old_name was intended.
        setattr(sas.models, old_path, ConstructedModule)
        sys.modules[old_path] = ConstructedModule
208
209
def MultiplicationModel(form_factor, structure_factor):
    # type: ("SasviewModel", "SasviewModel") -> "SasviewModel"
    """
    Returns a constructed product model from form_factor and structure_factor.

    The product model info is built from the model info of the two
    arguments, wrapped as a sasview model class, and instantiated with the
    multiplicity of *form_factor*.
    """
    model_info = product.make_product_info(form_factor._model_info,
                                           structure_factor._model_info)
    ConstructedModel = make_model_from_info(model_info)
    return ConstructedModel(form_factor.multiplicity)
219
220
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Generate the class attributes for the model.

    This should include all the information necessary to query the model
    details so that you do not need to instantiate a model to query it.

    All the attributes should be immutable to avoid accidents.

    Returns a dictionary suitable for use as the class namespace of a
    :class:`SasviewModel` subclass.
    """

    # TODO: allow model to override axis labels input/output name/unit

    # Process multiplicity: the control parameter, if any, determines the
    # number of variants and is never fitted.
    non_fittable = []  # type: List[str]
    xlabel = model_info.profile_axes[0] if model_info.profile is not None else ""
    variants = MultiplicityInfo(0, "", [], xlabel)
    for p in model_info.parameters.kernel_parameters:
        if p.name == model_info.control:
            non_fittable.append(p.name)
            variants = MultiplicityInfo(
                len(p.choices) if p.choices else int(p.limits[1]),
                p.name, p.choices, xlabel
            )
            break

    # Only a single drop-down list parameter available, so stop at the
    # first parameter that defines choices.
    fun_list = []
    for p in model_info.parameters.kernel_parameters:
        if p.choices:
            fun_list = p.choices
            if p.length > 1:
                non_fittable.extend(p.id+str(k) for k in range(1, p.length+1))
            break

    # Organize parameter sets
    orientation_params = []
    magnetic_params = []
    fixed = []
    for p in model_info.parameters.user_parameters({}, is2d=True):
        if p.type == 'orientation':
            orientation_params.append(p.name)
            orientation_params.append(p.name+".width")
            fixed.append(p.name+".width")
        elif p.type == 'magnetic':
            # Magnetic parameters are listed with the orientation parameters
            # as well as in their own list.
            orientation_params.append(p.name)
            magnetic_params.append(p.name)
            fixed.append(p.name+".width")

    # Build class dictionary; sequences are stored as tuples so that the
    # class attributes cannot be modified by accident.
    attrs = {}  # type: Dict[str, Any]
    attrs['_model_info'] = model_info
    attrs['name'] = model_info.name
    attrs['id'] = model_info.id
    attrs['description'] = model_info.description
    attrs['category'] = model_info.category
    attrs['is_structure_factor'] = model_info.structure_factor
    attrs['is_form_factor'] = model_info.ER is not None
    attrs['is_multiplicity_model'] = variants[0] > 1
    attrs['multiplicity_info'] = variants
    attrs['orientation_params'] = tuple(orientation_params)
    attrs['magnetic_params'] = tuple(magnetic_params)
    attrs['fixed'] = tuple(fixed)
    attrs['non_fittable'] = tuple(non_fittable)
    attrs['fun_list'] = tuple(fun_list)

    return attrs
289
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: List[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: List[str]
    #: names of the fittable parameters
    fixed = None              # type: List[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]
    #: 'results' attribute of the kernel calculator from the most recent
    #: evaluation, or None if it had none (see calc_composition_models).
    _intermediate_results = None

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Initialize the per-instance parameter tables from the model info.

        :param multiplicity: variant number for models with multiplicity,
            or None if the model has no multiplicity
        """

        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            hidden = self._model_info.get_hidden_parameters(multiplicity)
            hidden |= set([self.multiplicity_info.control])
        else:
            hidden = set()
        if self._model_info.structure_factor:
            # Structure factors do not expose scale and background; the
            # background default is forced to zero in the model info.
            hidden.add('scale')
            hidden.add('background')
            self._model_info.parameters.defaults['background'] = 0.

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters({}, is2d=True):
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    # NOTE(review): pickle and copy look for __getstate__/__setstate__; the
    # names below are not invoked by them.  Kept as-is for compatibility.
    def __get_state__(self):
        # type: () -> Dict[str, Any]
        """
        Return a copy of the instance dictionary without the kernel model.
        """
        state = self.__dict__.copy()
        # _model only appears in the instance dictionary after the first
        # evaluation, so it may legitimately be missing.
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __set_state__(self, state):
        # type: (Dict[str, Any]) -> None
        """
        Replace the instance dictionary with *state* and clear the kernel
        model so that it is rebuilt on the next evaluation.
        """
        self.__dict__ = state
        self._model = None

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check
        """
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()


    def getProfile(self):
        # type: () -> Tuple[np.ndarray, np.ndarray]
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        args = {} # type: Dict[str, Any]
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == self.multiplicity_info.control:
                value = float(self.multiplicity)
            elif p.length == 1:
                value = self.params.get(p.id, np.nan)
            else:
                # Vector parameter: gather p.id1, p.id2, ... into an array,
                # using NaN for any element that is not a current parameter.
                value = np.array([self.params.get(p.id+str(k), np.nan)
                                  for k in range(1, p.length+1)])
            args[p.id] = value

        x, y = self._model_info.profile(**args)
        return x, 1e-6*y

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        :param name: name of the parameter, or ``<parameter>.<field>`` for
            a field of the dispersion table of a polydisperse parameter
        :param value: value of the parameter

        Raises *ValueError* if the model does not contain the parameter.
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            fields = self.dispersion.get(toks[0])
            if fields is not None and toks[1] in fields:
                fields[toks[1]] = value
                return
        elif name in self.params:
            # Standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        :param name: name of the parameter, or ``<parameter>.<field>`` for
            a field of the dispersion table of a polydisperse parameter

        Raises *ValueError* if the model does not contain the parameter.
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            fields = self.dispersion.get(toks[0])
            if fields is not None and toks[1] in fields:
                return fields[toks[1]]
        elif name in self.params:
            # Standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model
        """
        # TODO: fix test so that parameter order doesn't matter
        ret = ['%s.%s' % (p_name, ext)
               for p_name in self.dispersion.keys()
               for ext in ('npts', 'nsigmas', 'width')]
        return ret

    def clone(self):
        # type: () -> "SasviewModel"
        """ Return a identical copy of self """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, Tuple[float, float], List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]


    def runXY(self, x=0.0):
        # type: (Union[float, Tuple[float, float], List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        Raises *TypeError* if *qdist* is neither a list/tuple nor an array.
        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calc_composition_models(self, qx):
        """
        returns parts of the composition model or None if not a composition
        model.
        """
        # TODO: have calculate_Iq return the intermediates.
        #
        # The current interface causes calculate_Iq() to be called twice,
        # once to get the combined result and again to get the intermediate
        # results.  This is necessary for now.
        # Long term, the solution is to change the interface to calculate_Iq
        # so that it returns a results object containing all the bits:
        #     the A, B, C, ... of the composition model (and any subcomponents?)
        #     the P and S of the product model,
        #     the combined model before resolution smearing,
        #     the sasmodel before sesans conversion,
        #     the oriented 2D model used to fit oriented usans data,
        #     the final I(q),
        #     ...
        #
        # Have the model calculator add all of these blindly to the data
        # tree, and update the graphs which contain them.  The fitter
        # needs to be updated to use the I(q) value only, ignoring the rest.
        #
        # The simple fix of returning the existing intermediate results
        # will not work for a couple of reasons: (1) another thread may
        # sneak in to compute its own results before calc_composition_models
        # is called, and (2) calculate_Iq is currently called three times:
        # once with q, once with q values before qmin and once with q values
        # after q max.  Both of these should be addressed before
        # replacing this code.
        composition = self._model_info.composition
        if composition and composition[0] == 'product': # only P*S for now
            with calculation_lock:
                self._calculate_Iq(qx)
                return self._intermediate_results
        else:
            return None

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        ## uncomment the following when trying to debug the uncoordinated calls
        ## to calculate_Iq
        #if calculation_lock.locked():
        #    logger.info("calculation waiting for another thread to complete")
        #    logger.info("\n".join(traceback.format_stack()))

        with calculation_lock:
            return self._calculate_Iq(qx, qy)

    def _calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Evaluate the kernel for *qx* (and *qy* if given) without taking the
        calculation lock; callers are expected to hold it.

        The kernel model is built on first use.  The calculator's *results*
        attribute, if any, is saved in *_intermediate_results*.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        calculator = self._model.make_kernel(q_vectors)
        try:
            parameters = self._model_info.parameters
            pairs = [self._get_weights(p) for p in parameters.call_parameters]
            #weights.plot_weights(self._model_info, pairs)
            call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
            #call_details.show()
            result = calculator(call_details, values, cutoff=self.cutoff,
                                magnetic=is_magnetic)
            self._intermediate_results = getattr(calculator, 'results', None)
        finally:
            # Release the calculator even if the evaluation fails.
            calculator.release()
        #self._model.release()
        return result

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            # Weighted average over the dispersion mesh
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion

        Raises *ValueError* if *parameter* is not a parameter of the model.
        """
        if parameter in self.params:
            # TODO: Store the disperser object directly in the model.
            # The current method of relying on the sasview GUI to
            # remember them is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.
            # (2) Could write a disperser parameter retriever in sasview.
            # (3) Could modify sasview to use sasmodels.weights dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

    def _dispersion_mesh(self):
        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[float, Sequence[float], Sequence[float]]
        """
        Return dispersion weights for parameter

        The result is *(value, dispersity, weight)*: the central value
        followed by the points of the distribution and their weights.
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return self.multiplicity, [self.multiplicity], [1.0]
            else:
                # For hidden parameters use default values.  This sets
                # scale=1 and background=0 for structure factors
                default = self._model_info.parameters.defaults.get(par.name, np.nan)
                return default, [default], [1.0]
        elif par.polydisperse:
            value = self.params[par.name]
            dis = self.dispersion[par.name]
            if dis['type'] == 'array':
                # User supplied distribution
                dispersity, weight = dis['values'], dis['weights']
            else:
                dispersity, weight = weights.get_weights(
                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                    value, par.limits, par.relative_pd)
            return value, dispersity, weight
        else:
            value = self.params[par.name]
            return value, [value], [1.0]
792
def test_cylinder():
    # type: () -> float
    """
    Test that the cylinder model runs, returning the value at [0.1,0.1].
    """
    Cylinder = _make_standard_model('cylinder')
    cylinder = Cylinder()
    # A list argument is treated as [qx, qy] by evalDistribution.
    return cylinder.evalDistribution([0.1, 0.1])
801
def test_structure_factor():
    # type: () -> None
    """
    Test that 2-D hardsphere model runs and doesn't produce NaN.

    The model is evaluated both at (qx, qy) = (0.1, 0.1) and at the 1-D
    point q = 0.1*sqrt(2); *ValueError* is raised if either result is NaN.
    """
    Model = _make_standard_model('hardsphere')
    model = Model()
    value2d = model.evalDistribution([0.1, 0.1])
    value1d = model.evalDistribution(np.array([0.1*np.sqrt(2)]))
    #print("hardsphere", value1d, value2d)
    if np.isnan(value1d) or np.isnan(value2d):
        raise ValueError("hardsphere returns nan")
814
def test_product():
    # type: () -> None
    """
    Test that the 2-D cylinder*hayter_msa product model runs and doesn't
    produce NaN.
    """
    S = _make_standard_model('hayter_msa')()
    P = _make_standard_model('cylinder')()
    model = MultiplicationModel(P, S)
    value = model.evalDistribution([0.1, 0.1])
    if np.isnan(value):
        raise ValueError("cylinder*hayter_msa returns nan")
826
def test_rpa():
    # type: () -> float
    """
    Test that the 2-D RPA model runs

    The model is created with multiplicity 3 and evaluated at
    (qx, qy) = (0.1, 0.1); the computed value is returned.
    """
    RPA = _make_standard_model('rpa')
    rpa = RPA(3)
    return rpa.evalDistribution([0.1, 0.1])
835
def test_empty_distribution():
    # type: () -> None
    """
    Make sure that sasmodels returns zero intensity (with the background set
    to zero) when there are no polydispersity points
    """
    Cylinder = _make_standard_model('cylinder')
    cylinder = Cylinder()
    # A negative radius is used to produce the empty distribution.
    cylinder.setParam('radius', -1.0)
    cylinder.setParam('background', 0.)
    Iq = cylinder.evalDistribution(np.asarray([0.1]))
    assert Iq[0] == 0., "empty distribution fails"
847
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models

    If a model fails to build, the exception is annotated with the model
    name and re-raised.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            annotate_exception("when loading "+name)
            raise
860
def test_old_name():
    # type: () -> None
    """
    Load and run cylinder model as sas-models-CylinderModel

    The test is skipped when old style plugins are not supported or when
    the *sas* package cannot be imported.
    """
    if not SUPPORT_OLD_STYLE_PLUGINS:
        return
    try:
        # if sasview is not on the path then don't try to test it
        import sas  # pylint: disable=unused-variable
    except ImportError:
        return
    # Registers the old model names as a side effect.
    load_standard_models()
    from sas.models.CylinderModel import CylinderModel
    CylinderModel().evalDistribution([0.1, 0.1])
876
def magnetic_demo():
    # type: () -> None
    """
    Plot the log of the 2-D sphere model intensity with the 'M0:sld'
    parameter set to 8, on a 500x500 grid of q values in [-0.35, 0.35].

    Requires pylab for plotting.
    """
    Model = _make_standard_model('sphere')
    model = Model()
    model.setParam('M0:sld', 8)
    q = np.linspace(-0.35, 0.35, 500)
    qx, qy = np.meshgrid(q, q)
    result = model.calculate_Iq(qx.flatten(), qy.flatten())
    result = result.reshape(qx.shape)

    import pylab
    # Offset keeps the log finite where the intensity is zero.
    pylab.imshow(np.log(result + 0.001))
    pylab.show()
889
if __name__ == "__main__":
    # Smoke test: evaluate the cylinder model at a single 2-D point.
    print("cylinder(0.1,0.1)=%g"%test_cylinder())
    #magnetic_demo()
    #test_product()
    #test_structure_factor()
    #print("rpa:", test_rpa())
    #test_empty_distribution()
Note: See TracBrowser for help on using the repository browser.