source: sasmodels/sasmodels/sasview_model.py @ 6da1d76

core_shell_microgels magnetic_model ticket-1257-vesicle-product ticket_1156 ticket_1265_superball ticket_822_more_unit_tests
Last change on this file since 6da1d76 was 6da1d76, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

Merge branch 'master' into ticket-1157

  • Property mode set to 100644
File size: 32.7 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17from os.path import basename, splitext, abspath, getmtime
18try:
19    import _thread as thread
20except ImportError:
21    import thread
22
23import numpy as np  # type: ignore
24
25from . import core
26from . import custom
27from . import product
28from . import generate
29from . import weights
30from . import modelinfo
31from .details import make_kernel_args, dispersion_mesh
32
33# pylint: disable=unused-import
34try:
35    from typing import (Dict, Mapping, Any, Sequence, Tuple, NamedTuple,
36                        List, Optional, Union, Callable)
37    from .modelinfo import ModelInfo, Parameter
38    from .kernel import KernelModel
39    MultiplicityInfoType = NamedTuple(
40        'MultiplicityInfo',
41        [("number", int), ("control", str), ("choices", List[str]),
42         ("x_axis_label", str)])
43    SasviewModelType = Callable[[int], "SasviewModel"]
44except ImportError:
45    pass
46# pylint: enable=unused-import
47
logger = logging.getLogger(__name__)

#: Lock held while a kernel is evaluated (calculate_Iq and
#: calc_composition_models both take it).
calculation_lock = thread.allocate_lock()

#: When True, plugins written against the old model names and parameters
#: can still be loaded.
SUPPORT_OLD_STYLE_PLUGINS = True

# TODO: separate x_axis_label from multiplicity info
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo', "number control choices x_axis_label")

#: Registry of all defined models (standard and custom), keyed by name.
MODELS = dict()  # type: Dict[str, SasviewModelType]
# TODO: remove unused MODEL_BY_PATH cache once sasview no longer references it
#: custom model {path: model} mapping so we can check timestamps
MODEL_BY_PATH = dict()  # type: Dict[str, SasviewModelType]
#: {path: module} for the custom kernel modules loaded so far, used to tell
#: whether a model has changed since it was last loaded.
_CACHED_MODULE = dict()  # type: Dict[str, "module"]
70
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Look up a model class by name.

    Names ending in ``.py`` are treated as custom model paths and loaded
    with :func:`load_custom_model`.  Any other name must already be a key
    of :data:`MODELS`.

    :raises ValueError: if a non-``.py`` name is not in :data:`MODELS`.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
85
86
# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Load and return the list of predefined models.

    Every name from ``core.list_models()`` is built and stored in
    :data:`MODELS`.  A model that fails to build has its traceback logged
    and is left out.  When :data:`SUPPORT_OLD_STYLE_PLUGINS` is set, the old
    plugin names are registered as well.  The return value holds every
    model currently in :data:`MODELS`.
    """
    for model_name in core.list_models():
        try:
            constructed = _make_standard_model(model_name)
        except Exception:
            logger.error(traceback.format_exc())
            continue
        MODELS[model_name] = constructed

    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()

    return list(MODELS.values())
105
106
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load a custom model given the model path.

    Old-style plugin modules define a ``Model`` class.  That class is used
    directly, and its *name*, *filename* and *id* are filled in when they
    are missing.  Other modules are converted with
    :func:`make_model_from_info`.

    If the model name is already registered for a different file, the model
    is renamed to its id.  If that name is also taken, it becomes its id
    with a ``_user`` suffix.

    Returns the class registered in :data:`MODELS` under the final name.
    That entry is only replaced when the module was reloaded or the name
    was not yet registered.
    """
    #logger.info("Loading model %s", path)

    # Load the kernel module.  This may already be cached by the loader, so
    # only requires checking the timestamps of the dependents.
    kernel_module = custom.load_custom_kernel_module(path)

    # Check if the module has changed since we last looked.  A different
    # module object (compared by identity) means it was reloaded.
    reloaded = kernel_module is not _CACHED_MODULE.get(path, None)
    _CACHED_MODULE[path] = kernel_module

    # Turn the module into a model.  We need to do this even if the
    # model has already been loaded so that we can determine the model
    # name and retrieve it from the MODELS cache.
    model = getattr(kernel_module, 'Model', None)
    if model is not None:
        # Old style models do not set the name in the class attributes, so
        # set it here; this name will be overridden when the object is created
        # with an instance variable that has the same value.
        if model.name == "":
            model.name = splitext(basename(path))[0]
        if not hasattr(model, 'filename'):
            filename = abspath(kernel_module.__file__)
            # Point at the source file rather than the compiled one.  Only
            # the extension changes; str.replace would also rewrite any
            # '.pyc' appearing earlier in the path.
            if filename.endswith('.pyc'):
                filename = filename[:-1]
            model.filename = filename
        if not hasattr(model, 'id'):
            model.id = splitext(basename(model.filename))[0]
    else:
        model_info = modelinfo.make_model_info(kernel_module)
        model = make_model_from_info(model_info)

    # If a model name already exists and we are loading a different model,
    # use the model file name as the model name.
    if model.name in MODELS and model.filename != MODELS[model.name].filename:
        previous_name = model.name
        model.name = model.id

        # If the new model name is still in the model list (for instance,
        # if we put a cylinder.py in our plug-in directory), then append
        # an identifier.
        if model.name in MODELS and model.filename != MODELS[model.name].filename:
            model.name = model.id + '_user'
        logger.info("Model %s already exists: using %s [%s]",
                    previous_name, model.name, model.filename)

    # Only update the model if the module has changed
    if reloaded or model.name not in MODELS:
        MODELS[model.name] = model

    return MODELS[model.name]
159
160
def make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Build a SasView model class from *model_info*.

    The new class subclasses :class:`SasviewModel` and is named after the
    model.  Its class attributes come from :func:`_generate_model_attributes`,
    plus the model's source *filename*.
    """
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs = _generate_model_attributes(model_info)
    class_attrs['__init__'] = __init__
    class_attrs['filename'] = model_info.filename
    return type(model_info.name, (SasviewModel,), class_attrs)  # type: SasviewModelType
173
174
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Load the sasview model defined by *name*.

    *name* is resolved by ``generate.load_kernel_module``.  The original
    documentation says this may be a standard model name or a path to a
    custom model.

    Returns a class that can be used directly as a sasview model.
    """
    return make_model_from_info(
        modelinfo.make_model_info(generate.load_kernel_module(name)))
187
188
def _register_old_models():
    # type: () -> None
    """
    Place the new models into sasview under the old names.

    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
    is available to the plugin modules.

    For each entry of the 3.1.2 conversion table, a small namespace class
    named after the old model is created.  It holds the new model class
    under the old name and is registered twice:

    - as the attribute ``sas.models.<old_name>``;
    - as ``sys.modules['sas.models.<old_name>']``.

    Models that :func:`find_model` cannot find are skipped with a warning
    instead of aborting the registration of the others.
    """
    import sys
    import sas   # needed in order to set sas.models
    import sas.sascalc.fit
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit
    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE

    for new_name, conversion in CONVERSION_TABLE.get((3, 1, 2), {}).items():
        # CoreShellEllipsoidModel => core_shell_ellipsoid:1
        new_name = new_name.split(':')[0]
        old_name = conversion[0] if len(conversion) < 3 else conversion[2]
        try:
            new_model = find_model(new_name)
        except ValueError:
            # find_model raises ValueError for names missing from MODELS,
            # e.g. a model whose load failed in load_standard_models.
            logger.warning("Cannot register old model %s: model %r not loaded",
                           old_name, new_name)
            continue
        ConstructedModule = type(old_name, (), {old_name: new_model})
        old_path = 'sas.models.' + old_name
        # The attribute must use the bare old name; an attribute named with
        # the dotted path is unreachable by normal attribute access.
        setattr(sas.models, old_name, ConstructedModule)
        sys.modules[old_path] = ConstructedModule
214
215
def MultiplicationModel(form_factor, structure_factor):
    # type: ("SasviewModel", "SasviewModel") -> "SasviewModel"
    """
    Returns a constructed product model from form_factor and structure_factor.

    A product model class is built from the two models' model info and then
    instantiated with the multiplicity of *form_factor*.
    """
    p_info, s_info = form_factor._model_info, structure_factor._model_info
    product_info = product.make_product_info(p_info, s_info)
    return make_model_from_info(product_info)(form_factor.multiplicity)
225
226
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Generate the class attributes for the model.

    The result holds everything needed to query the model details without
    instantiating it:

    - names, description and category;
    - the multiplicity (variant) information;
    - the orientation and magnetic parameter names;
    - the fixed and non-fittable parameter names.

    All the attributes should be immutable to avoid accidents, which is
    why the name lists are stored as tuples.
    """
    # TODO: allow model to override axis labels input/output name/unit

    xlabel = "" if model_info.profile is None else model_info.profile_axes[0]

    # Process multiplicity: the control parameter (if any) sets the number
    # of variants and is itself not fittable.
    control = next((p for p in model_info.parameters.kernel_parameters
                    if p.name == model_info.control), None)
    if control is None:
        variants = MultiplicityInfo(0, "", [], xlabel)
        non_fittable = []  # type: List[str]
    else:
        num_variants = (len(control.choices) if control.choices
                        else int(control.limits[1]))
        variants = MultiplicityInfo(num_variants, control.name,
                                    control.choices, xlabel)
        non_fittable = [control.name]

    # Only a single drop-down list parameter available
    selector = next((p for p in model_info.parameters.kernel_parameters
                     if p.choices), None)
    if selector is None:
        fun_list = []
    else:
        fun_list = selector.choices
        if selector.length > 1:
            non_fittable.extend(selector.id + str(k)
                                for k in range(1, selector.length + 1))

    # Organize parameter sets
    orientation_params = []
    magnetic_params = []
    fixed = []
    for par in model_info.parameters.user_parameters({}, is2d=True):
        width_name = par.name + ".width"
        if par.type == 'orientation':
            orientation_params.extend((par.name, width_name))
            fixed.append(width_name)
        elif par.type == 'magnetic':
            orientation_params.append(par.name)
            magnetic_params.append(par.name)
            fixed.append(width_name)

    # Build class dictionary
    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }  # type: Dict[str, Any]
295
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete subclasses are built by make_model_from_info.  That function
    fills in the class attributes below from the model info.
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: List[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: List[str]
    #: names of the fittable parameters
    fixed = None              # type: List[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]
    #: the ``results`` attribute of the last kernel calculator, or None if
    #: the calculator had none; read by calc_composition_models.
    _intermediate_results = None # type: Any

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Create the model with default parameter values.

        If *multiplicity* is given, the parameters hidden for that variant
        are left out of *params*, and so is the multiplicity control
        parameter.  Structure factors also hide *scale* and *background*.
        """

        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            hidden = self._model_info.get_hidden_parameters(multiplicity)
            hidden |= set([self.multiplicity_info.control])
        else:
            hidden = set()
        if self._model_info.structure_factor:
            hidden.add('scale')
            hidden.add('background')

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters({}, is2d=True):
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __getstate__(self):
        # type: () -> Dict[str, Any]
        """
        Return the instance state for pickling and copying.

        The compiled kernel model (``_model``) is dropped.  It is rebuilt on
        demand by the next calculation.
        """
        state = self.__dict__.copy()
        # The kernel is only present once a calculation has been run.
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __setstate__(self, state):
        # type: (Dict[str, Any]) -> None
        """Restore *state*; the kernel model is rebuilt lazily."""
        self.__dict__ = state
        self._model = None

    # Backward-compatible aliases for the previous, misspelled hook names.
    # pickle and copy only look up __getstate__ and __setstate__.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check
        """
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()


    def getProfile(self):
        # type: () -> Tuple[np.ndarray, np.ndarray]
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        args = {} # type: Dict[str, Any]
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == self.multiplicity_info.control:
                value = float(self.multiplicity)
            elif p.length == 1:
                value = self.params.get(p.id, np.nan)
            else:
                # Vector parameters are stored as p.id1 ... p.idN.
                value = np.array([self.params.get(p.id+str(k), np.nan)
                                  for k in range(1, p.length+1)])
            args[p.id] = value

        x, y = self._model_info.profile(**args)
        # The profile's SLD values are scaled by 1e-6 before returning.
        return x, 1e-6*y

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        :param name: name of the parameter, or "par.field" for a dispersion
            setting such as "radius.width"
        :param value: value of the parameter
        :raises ValueError: if the model does not contain the parameter

        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            base, field = toks
            pd_settings = self.dispersion.get(base)
            if pd_settings is not None and field in pd_settings:
                pd_settings[field] = value
                return
        elif name in self.params:
            # Look for standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        :param name: name of the parameter, or "par.field" for a dispersion
            setting such as "radius.width"
        :raises ValueError: if the model does not contain the parameter

        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            base, field = toks
            pd_settings = self.dispersion.get(base)
            if pd_settings is not None and field in pd_settings:
                return pd_settings[field]
        elif name in self.params:
            # Look for standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model
        """
        # TODO: fix test so that parameter order doesn't matter
        return ['%s.%s' % (p_name, ext)
                for p_name in self.dispersion.keys()
                for ext in ('npts', 'nsigmas', 'width')]

    def clone(self):
        # type: () -> "SasviewModel"
        """
        Return an identical copy of self.

        The copy is a deepcopy taken through __getstate__/__setstate__, so
        it does not share the compiled kernel; it builds its own on first use.
        """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]


    def runXY(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calc_composition_models(self, qx):
        """
        Return the parts of a composition model, evaluated at *qx*.

        Only product (P*S) models are supported.  For those, the model is
        recomputed and the ``results`` of the kernel calculator are
        returned.  Any other model returns None.
        """
        # TODO: have calculate_Iq return the intermediates.
        #
        # The current interface causes calculate_Iq() to be called twice,
        # once to get the combined result and again to get the intermediate
        # results.  This is necessary for now.
        # Long term, the solution is to change the interface to calculate_Iq
        # so that it returns a results object containing all the bits:
        #     the A, B, C, ... of the composition model (and any subcomponents?)
        #     the P and S of the product model,
        #     the combined model before resolution smearing,
        #     the sasmodel before sesans conversion,
        #     the oriented 2D model used to fit oriented usans data,
        #     the final I(q),
        #     ...
        #
        # Have the model calculator add all of these blindly to the data
        # tree, and update the graphs which contain them.  The fitter
        # needs to be updated to use the I(q) value only, ignoring the rest.
        #
        # The simple fix of returning the existing intermediate results
        # will not work for a couple of reasons: (1) another thread may
        # sneak in to compute its own results before calc_composition_models
        # is called, and (2) calculate_Iq is currently called three times:
        # once with q, once with q values before qmin and once with q values
        # after q max.  Both of these should be addressed before
        # replacing this code.
        composition = self._model_info.composition
        if composition and composition[0] == 'product': # only P*S for now
            with calculation_lock:
                self._calculate_Iq(qx)
                return self._intermediate_results
        else:
            return None

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        ## uncomment the following when trying to debug the uncoordinated calls
        ## to calculate_Iq
        #if calculation_lock.locked():
        #    logger.info("calculation waiting for another thread to complete")
        #    logger.info("\n".join(traceback.format_stack()))

        with calculation_lock:
            return self._calculate_Iq(qx, qy)

    def _calculate_Iq(self, qx, qy=None):
        """
        Evaluate the kernel at *qx* (and *qy* for 2D) without locking.

        Both callers in this class hold calculation_lock around this call.
        The kernel model is built on first use.  The calculator's
        ``results`` are stored in _intermediate_results.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        calculator = self._model.make_kernel(q_vectors)
        parameters = self._model_info.parameters
        pairs = [self._get_weights(p) for p in parameters.call_parameters]
        call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
        result = calculator(call_details, values, cutoff=self.cutoff,
                            magnetic=is_magnetic)
        self._intermediate_results = getattr(calculator, 'results', None)
        calculator.release()
        return result

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius (the dispersity-weighted
            mean of the model's ER), or 1.0 if the model defines no ER
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio, or 1.0 if the model defines
            no VR
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion
        :raises ValueError: if *parameter* is not a model parameter
        """
        if parameter in self.params:
            # TODO: Store the disperser object directly in the model.
            # The current method of relying on the sasview GUI to
            # remember them is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.
            # (2) Could write a disperser parameter retriever in sasview.
            # (3) Could modify sasview to use sasmodels.weights dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

    def _dispersion_mesh(self):
        # type: () -> Tuple[List[np.ndarray], np.ndarray]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[Any, Sequence[float], Sequence[float]]
        """
        Return (value, dispersity points, weights) for parameter *par*.
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return self.multiplicity, [self.multiplicity], [1.0]
            else:
                # For hidden parameters use default values.  This sets
                # scale=1 and background=0 for structure factors
                default = self._model_info.parameters.defaults.get(par.name, np.nan)
                return default, [default], [1.0]
        elif par.polydisperse:
            value = self.params[par.name]
            dis = self.dispersion[par.name]
            if dis['type'] == 'array':
                dispersity, weight = dis['values'], dis['weights']
            else:
                dispersity, weight = weights.get_weights(
                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                    value, par.limits, par.relative_pd)
            return value, dispersity, weight
        else:
            value = self.params[par.name]
            return value, [value], [1.0]
797
def test_cylinder():
    # type: () -> float
    """
    Smoke test for the standard cylinder model.

    Builds the model with its default parameters, evaluates it at the
    2-D point [0.1, 0.1], and returns the result.
    """
    cylinder_model = _make_standard_model('cylinder')()
    q_point = [0.1, 0.1]
    return cylinder_model.evalDistribution(q_point)
806
def test_structure_factor():
    # type: () -> None
    """
    Check that the hardsphere structure factor does not produce NaN.

    The model is evaluated at the 2-D point [0.1, 0.1] and at the 1-D point
    with the same magnitude, |q| = 0.1*sqrt(2).  ValueError is raised if
    either result is NaN.
    """
    hardsphere = _make_standard_model('hardsphere')()
    iq_2d = hardsphere.evalDistribution([0.1, 0.1])
    iq_1d = hardsphere.evalDistribution(np.array([0.1*np.sqrt(2)]))
    if np.isnan(iq_1d) or np.isnan(iq_2d):
        raise ValueError("hardsphere returns nan")
819
def test_product():
    # type: () -> None
    """
    Test that the 2-D cylinder*hayter_msa product model runs and doesn't
    produce NaN.

    Raises ValueError if the product evaluated at [0.1, 0.1] is NaN.
    """
    S = _make_standard_model('hayter_msa')()
    P = _make_standard_model('cylinder')()
    model = MultiplicationModel(P, S)
    value = model.evalDistribution([0.1, 0.1])
    if np.isnan(value):
        raise ValueError("cylinder*hayter_msa returns NaN")
831
def test_rpa():
    # type: () -> float
    """
    Smoke test for the 2-D RPA model.

    Constructs the model with argument 3, evaluates it at the 2-D point
    [0.1, 0.1], and returns the result.
    """
    rpa_model = _make_standard_model('rpa')(3)
    return rpa_model.evalDistribution([0.1, 0.1])
840
def test_empty_distribution():
    # type: () -> None
    """
    Check I(q) when the distribution has no polydispersity points.

    The cylinder radius is set to -1, presumably leaving no valid radius
    points.  With the background set to 0, the returned I(q) must then be
    exactly 0, not NaN.
    """
    cylinder = _make_standard_model('cylinder')()
    for name, value in (('radius', -1.0), ('background', 0.)):
        cylinder.setParam(name, value)
    iq_values = cylinder.evalDistribution(np.asarray([0.1]))
    assert iq_values[0] == 0., "empty distribution fails"
852
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models.

    Any error raised while building a model is annotated with the model
    name and re-raised.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            # Catch Exception rather than everything, so that
            # KeyboardInterrupt and SystemExit propagate untouched instead
            # of being labelled as a model-loading failure.
            annotate_exception("when loading "+name)
            raise
865
def test_old_name():
    # type: () -> None
    """
    Check that the old-style sas.models.CylinderModel name loads and runs.

    Does nothing when old-style plugin support is disabled, or when the
    ``sas`` (sasview) package cannot be imported.
    """
    if SUPPORT_OLD_STYLE_PLUGINS:
        try:
            # Only probing whether sasview is on the path.
            import sas
        except ImportError:
            pass
        else:
            load_standard_models()
            from sas.models.CylinderModel import CylinderModel
            CylinderModel().evalDistribution([0.1, 0.1])
881
def test_structure_factor_background():
    # type: () -> None
    """
    Compare the sasview model with DirectModel for hardsphere at q=0.

    The sasview result must equal the DirectModel result both with the
    background explicitly set to 0 and with DirectModel's default
    background.  The second comparison checks that the default is zero.
    """
    from .data import empty_data1D
    from .core import load_model_info, build_model
    from .direct_model import DirectModel

    name = "hardsphere"
    q_values = [0.0]

    wrapped = _make_standard_model(name)()
    expected = wrapped.evalDistribution(np.array(q_values))[0]

    direct = DirectModel(empty_data1D(q_values),
                         build_model(load_model_info(name)))
    assert expected == direct(background=0.0)

    # The default background must also be zero.
    assert expected == direct()
908
909
def magnetic_demo():
    """
    Display log I(qx, qy) for the sphere model with 'M0:sld' set to 8.

    The model is evaluated on a 500x500 grid with qx and qy in
    [-0.35, 0.35].  The plot shows log(I + 0.001) as an image.  Requires
    matplotlib.
    """
    Model = _make_standard_model('sphere')
    model = Model()
    model.setParam('M0:sld', 8)
    q = np.linspace(-0.35, 0.35, 500)
    qx, qy = np.meshgrid(q, q)
    result = model.calculate_Iq(qx.flatten(), qy.flatten())
    result = result.reshape(qx.shape)

    # Use pyplot rather than the pylab interface, which matplotlib discourages.
    import matplotlib.pyplot as plt
    # The small offset keeps the log finite where I(q) is zero.
    plt.imshow(np.log(result + 0.001))
    plt.show()
922
if __name__ == "__main__":
    # Quick smoke test when this module is run as a script.  Other checks
    # that can be enabled by hand:
    #   magnetic_demo(), test_product(), test_structure_factor(),
    #   print("rpa:", test_rpa()), test_empty_distribution(),
    #   test_structure_factor_background()
    cylinder_value = test_cylinder()
    print("cylinder(0.1,0.1)=%g" % cylinder_value)
Note: See TracBrowser for help on using the repository browser.