source: sasmodels/sasmodels/sasview_model.py @ 0535624

Branches: core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 0535624 was 0535624, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

set background=0 in modelinfo.defaults, and simplify sasview model wrapper

  • Property mode set to 100644
File size: 32.1 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17from os.path import basename, splitext, abspath, getmtime
18try:
19    import _thread as thread
20except ImportError:
21    import thread
22
23import numpy as np  # type: ignore
24
25from . import core
26from . import custom
27from . import product
28from . import generate
29from . import weights
30from . import modelinfo
31from .details import make_kernel_args, dispersion_mesh
32
33# pylint: disable=unused-import
34try:
35    from typing import (Dict, Mapping, Any, Sequence, Tuple, NamedTuple,
36                        List, Optional, Union, Callable)
37    from .modelinfo import ModelInfo, Parameter
38    from .kernel import KernelModel
39    MultiplicityInfoType = NamedTuple(
40        'MultiplicityInfo',
41        [("number", int), ("control", str), ("choices", List[str]),
42         ("x_axis_label", str)])
43    SasviewModelType = Callable[[int], "SasviewModel"]
44except ImportError:
45    pass
46# pylint: enable=unused-import
47
logger = logging.getLogger(__name__)

# One lock shared by every model instance; kernel evaluation in
# SasviewModel.calculate_Iq and calc_composition_models runs while holding it.
calculation_lock = thread.allocate_lock()

#: True if pre-existing plugins, with the old names and parameters, should
#: continue to be supported.
SUPPORT_OLD_STYLE_PLUGINS = True

# TODO: separate x_axis_label from multiplicity info
#: (number, control, choices, x_axis_label) describing a model's multiplicity
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo', 'number control choices x_axis_label')

#: set of defined models (standard and custom), keyed by model name
MODELS = dict()  # type: Dict[str, SasviewModelType]
#: custom model {path: model} mapping so we can check timestamps
MODEL_BY_PATH = dict()  # type: Dict[str, SasviewModelType]
66
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Look up a model class by name.

    A name ending in ``.py`` is treated as a file path and loaded with
    :func:`load_custom_model`.  Any other name must already be registered
    in :data:`MODELS`; otherwise *ValueError* is raised.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
81
82
# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Build a sasview model class for every model in core.list_models(),
    register each one in :data:`MODELS`, and return all registered models.

    Loading is best effort: a model that fails to build has its traceback
    logged at error level and is left out.  When old-style plugin support
    is enabled, the old sas.models names are registered as well.
    """
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
        except Exception:
            logger.error(traceback.format_exc())
            continue
        MODELS[model_name] = model_class

    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()

    return list(MODELS.values())
101
102
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load a custom model given the model path.

    The loaded class is cached in :data:`MODEL_BY_PATH` together with the
    file modification time.  The cached class is returned unless the file
    has been modified since it was loaded.  The model is also registered in
    :data:`MODELS`.  If a *different* model (different filename) already
    uses the same name, the model is renamed to its id, or to
    ``<id>_user`` if that is also taken.
    """
    # Read the timestamp once, *before* loading.  Edits made while the
    # module is being loaded then show up as a changed timestamp on the
    # next call, instead of being stamped as already loaded.
    timestamp = getmtime(path)
    model = MODEL_BY_PATH.get(path, None)
    if model is not None and model.timestamp == timestamp:
        #logger.info("Model already loaded %s", path)
        return model

    #logger.info("Loading model %s", path)
    kernel_module = custom.load_custom_kernel_module(path)
    if hasattr(kernel_module, 'Model'):
        model = kernel_module.Model
        # Old style models do not set the name in the class attributes, so
        # set it here; this name will be overridden when the object is created
        # with an instance variable that has the same value.
        if model.name == "":
            model.name = splitext(basename(path))[0]
        if not hasattr(model, 'filename'):
            filename = abspath(kernel_module.__file__)
            # Point at the source file rather than the compiled one.  Only
            # the extension is changed; ".pyc" elsewhere in the path is kept.
            if filename.endswith('.pyc'):
                filename = filename[:-1]
            model.filename = filename
        if not hasattr(model, 'id'):
            model.id = splitext(basename(model.filename))[0]
    else:
        model_info = modelinfo.make_model_info(kernel_module)
        model = make_model_from_info(model_info)
    model.timestamp = timestamp

    # If a model name already exists and we are loading a different model,
    # use the model file name as the model name.
    if model.name in MODELS and model.filename != MODELS[model.name].filename:
        previous_name = model.name
        model.name = model.id

        # If the new model name is still in the model list (for instance,
        # if we put a cylinder.py in our plug-in directory), then append
        # an identifier.
        if model.name in MODELS and model.filename != MODELS[model.name].filename:
            model.name = model.id + '_user'
        logger.info("Model %s already exists: using %s [%s]",
                    previous_name, model.name, model.filename)

    MODELS[model.name] = model
    MODEL_BY_PATH[path] = model
    return model
148
149
def make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Create a SasviewModel subclass named after *model_info*.

    The class attributes come from :func:`_generate_model_attributes`, plus
    ``filename`` from *model_info* and an ``__init__`` that forwards
    *multiplicity* to :meth:`SasviewModel.__init__`.
    """
    attrs = _generate_model_attributes(model_info)

    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    attrs.update(__init__=__init__, filename=model_info.filename)
    return type(model_info.name, (SasviewModel,), attrs)  # type: SasviewModelType
162
163
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build the sasview model class for the kernel module *name*.

    The module is located with generate.load_kernel_module, converted to
    a ModelInfo, and wrapped with :func:`make_model_from_info`.  The
    returned class can be used directly as a sasview model.
    """
    return make_model_from_info(
        modelinfo.make_model_info(generate.load_kernel_module(name)))
176
177
def _register_old_models():
    # type: () -> None
    """
    Place the new models into sasview under the old names.

    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
    is available to the plugin modules.  Each old model name from the 3.1.2
    conversion table gets a tiny module-like class holding the new model.
    That class is registered both as ``sys.modules['sas.models.<old>']``
    (for ``from sas.models.<old> import <old>``) and as the attribute
    ``sas.models.<old>`` (for attribute access after ``import sas.models``).

    New models that cannot be found (e.g. because they failed to load in
    load_standard_models) are skipped with a warning, rather than aborting
    the whole registration.
    """
    import sys
    import sas   # needed in order to set sas.models
    import sas.sascalc.fit
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit
    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE

    for new_name, conversion in CONVERSION_TABLE.get((3, 1, 2), {}).items():
        # CoreShellEllipsoidModel => core_shell_ellipsoid:1
        new_name = new_name.split(':')[0]
        old_name = conversion[0] if len(conversion) < 3 else conversion[2]
        try:
            model = find_model(new_name)
        except ValueError:
            logger.warning("Old model name %s not registered: model %s not found",
                           old_name, new_name)
            continue
        module_attrs = {old_name: model}
        ConstructedModule = type(old_name, (), module_attrs)
        old_path = 'sas.models.' + old_name
        # The attribute name must be the bare old name; an attribute named
        # 'sas.models.X' could never be reached by attribute access.
        setattr(sas.models, old_name, ConstructedModule)
        sys.modules[old_path] = ConstructedModule
203
204
def MultiplicationModel(form_factor, structure_factor):
    # type: ("SasviewModel", "SasviewModel") -> "SasviewModel"
    """
    Build the P(q)*S(q) product of two sasview models.

    Returns an *instance* of the product model class, created with the
    multiplicity of *form_factor*.
    """
    product_info = product.make_product_info(
        form_factor._model_info, structure_factor._model_info)
    product_class = make_model_from_info(product_info)
    return product_class(form_factor.multiplicity)
214
215
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Compute the class dictionary for the sasview model built from *model_info*.

    Everything needed to query the model's details is stored as a class
    attribute, so no instance is needed for such queries.  Sequences are
    stored as tuples to discourage accidental mutation.
    """
    # TODO: allow model to override axis labels input/output name/unit

    kernel_pars = model_info.parameters.kernel_parameters
    xlabel = "" if model_info.profile is None else model_info.profile_axes[0]

    # Multiplicity comes from the first kernel parameter named as control.
    non_fittable = []  # type: List[str]
    variants = MultiplicityInfo(0, "", [], xlabel)
    control_par = next(
        (par for par in kernel_pars if par.name == model_info.control), None)
    if control_par is not None:
        non_fittable.append(control_par.name)
        if control_par.choices:
            count = len(control_par.choices)
        else:
            count = int(control_par.limits[1])
        variants = MultiplicityInfo(count, control_par.name,
                                    control_par.choices, xlabel)

    # Only a single drop-down list parameter available: the first with choices
    fun_list = []
    choice_par = next((par for par in kernel_pars if par.choices), None)
    if choice_par is not None:
        fun_list = choice_par.choices
        if choice_par.length > 1:
            non_fittable.extend(choice_par.id + str(k)
                                for k in range(1, choice_par.length + 1))

    # Orientation and magnetic parameters; their ".width" names go in fixed
    orientation_params = []
    magnetic_params = []
    fixed = []
    for par in model_info.parameters.user_parameters({}, is2d=True):
        width_name = par.name + ".width"
        if par.type == 'orientation':
            orientation_params.extend([par.name, width_name])
            fixed.append(width_name)
        elif par.type == 'magnetic':
            orientation_params.append(par.name)
            magnetic_params.append(par.name)
            fixed.append(width_name)

    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }  # type: Dict[str, Any]
284
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete subclasses are created by make_model_from_info, which fills in
    the class attributes below from a ModelInfo object.
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: List[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: List[str]
    #: names of the fittable parameters
    fixed = None              # type: List[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]
    #: calculator.results from the last kernel evaluation, if it had any
    #: (returned by calc_composition_models for product models)
    _intermediate_results = None # type: Any

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None

        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            hidden = self._model_info.get_hidden_parameters(multiplicity)
            hidden |= set([self.multiplicity_info.control])
        else:
            hidden = set()
        if self._model_info.structure_factor:
            hidden.add('scale')
            hidden.add('background')

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters({}, is2d=True):
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __getstate__(self):
        # type: () -> Dict[str, Any]
        """
        Return the instance state for pickle/deepcopy, without the compiled
        kernel model; _calculate_Iq rebuilds it when it is None.
        """
        state = self.__dict__.copy()
        # _model is only in the instance dict after the first evaluation.
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __setstate__(self, state):
        # type: (Dict[str, Any]) -> None
        """
        Restore state produced by __getstate__; the kernel is rebuilt lazily.
        """
        self.__dict__ = state
        self._model = None

    # Backward-compatible aliases for the original, misspelled names, which
    # pickle and copy never called.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        Returns True if *par_name* is listed in the ``fixed`` class attribute
        (see the TODO on that attribute about its naming).

        :param par_name: the parameter name to check
        """
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()


    def getProfile(self):
        # type: () -> (np.ndarray, np.ndarray)
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values

        Parameters missing from self.params are passed to the profile
        function as NaN.
        """
        args = {} # type: Dict[str, Any]
        control = self.multiplicity_info.control
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == control and self.multiplicity is not None:
                # The control parameter is hidden from self.params when a
                # multiplicity is given; without one, read it from params.
                value = float(self.multiplicity)
            elif p.length == 1:
                value = self.params.get(p.id, np.nan)
            else:
                value = np.array([self.params.get(p.id+str(k), np.nan)
                                  for k in range(1, p.length+1)])
            args[p.id] = value

        x, y = self._model_info.profile(**args)
        return x, 1e-6*y

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        A name of the form ``par.field`` (exactly one dot) addresses a
        dispersion field such as ``radius.width``; any other name addresses
        a parameter value.

        :param name: name of the parameter
        :param value: value of the parameter

        :raises ValueError: if the parameter does not exist
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            dispersion = self.dispersion.get(toks[0])
            if dispersion is not None and toks[1] in dispersion:
                dispersion[toks[1]] = value
                return
        elif name in self.params:
            # Look for standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        Uses the same naming rules as setParam: ``par.field`` for dispersion
        fields, the plain name for parameter values.

        :param name: name of the parameter

        :raises ValueError: if the parameter does not exist
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            dispersion = self.dispersion.get(toks[0])
            if dispersion is not None and toks[1] in dispersion:
                return dispersion[toks[1]]
        elif name in self.params:
            # Look for standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model
        """
        # TODO: fix test so that parameter order doesn't matter
        ret = ['%s.%s' % (p_name, ext)
               for p_name in self.dispersion.keys()
               for ext in ('npts', 'nsigmas', 'width')]
        return ret

    def clone(self):
        # type: () -> "SasviewModel"
        """
        Return a identical copy of self

        The compiled kernel is not copied (see __getstate__); the copy
        rebuilds it on first evaluation.
        """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]


    def runXY(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calc_composition_models(self, qx):
        """
        returns parts of the composition model or None if not a composition
        model.
        """
        # TODO: have calculate_Iq return the intermediates.
        #
        # The current interface causes calculate_Iq() to be called twice,
        # once to get the combined result and again to get the intermediate
        # results.  This is necessary for now.
        # Long term, the solution is to change the interface to calculate_Iq
        # so that it returns a results object containing all the bits:
        #     the A, B, C, ... of the composition model (and any subcomponents?)
        #     the P and S of the product model,
        #     the combined model before resolution smearing,
        #     the sasmodel before sesans conversion,
        #     the oriented 2D model used to fit oriented usans data,
        #     the final I(q),
        #     ...
        #
        # Have the model calculator add all of these blindly to the data
        # tree, and update the graphs which contain them.  The fitter
        # needs to be updated to use the I(q) value only, ignoring the rest.
        #
        # The simple fix of returning the existing intermediate results
        # will not work for a couple of reasons: (1) another thread may
        # sneak in to compute its own results before calc_composition_models
        # is called, and (2) calculate_Iq is currently called three times:
        # once with q, once with q values before qmin and once with q values
        # after q max.  Both of these should be addressed before
        # replacing this code.
        composition = self._model_info.composition
        if composition and composition[0] == 'product': # only P*S for now
            with calculation_lock:
                self._calculate_Iq(qx)
                return self._intermediate_results
        else:
            return None

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        ## uncomment the following when trying to debug the uncoordinated calls
        ## to calculate_Iq
        #if calculation_lock.locked():
        #    logger.info("calculation waiting for another thread to complete")
        #    logger.info("\n".join(traceback.format_stack()))

        with calculation_lock:
            return self._calculate_Iq(qx, qy)

    def _calculate_Iq(self, qx, qy=None):
        """
        Evaluate the kernel at *qx* (and *qy* for 2D) without locking.

        Both callers in this class hold calculation_lock.  Builds the
        kernel model on first use and saves the calculator's ``results``
        (if any) in ``self._intermediate_results``.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        calculator = self._model.make_kernel(q_vectors)
        try:
            parameters = self._model_info.parameters
            pairs = [self._get_weights(p) for p in parameters.call_parameters]
            call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
            result = calculator(call_details, values, cutoff=self.cutoff,
                                magnetic=is_magnetic)
            self._intermediate_results = getattr(calculator, 'results', None)
        finally:
            # Release the kernel even when weight generation or the
            # evaluation itself raises.
            calculator.release()
        return result

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius, or 1.0 if the model
            defines no ER function
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio, or 1.0 if the model defines
            no VR function
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion

        :raises ValueError: if *parameter* is not in self.params
        """
        if parameter in self.params:
            # TODO: Store the disperser object directly in the model.
            # The current method of relying on the sasview GUI to
            # remember them is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.
            # (2) Could write a disperser parameter retriever in sasview.
            # (3) Could modify sasview to use sasmodels.weights dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

    def _dispersion_mesh(self):
        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
        """
        Return (value, dispersity, weight) for parameter *par*.

        Parameters not in self.params are the multiplicity control (which
        uses self.multiplicity) or hidden parameters (which use the model
        default, NaN if there is none).
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return self.multiplicity, [self.multiplicity], [1.0]
            else:
                # For hidden parameters use default values.  This sets
                # scale=1 and background=0 for structure factors
                default = self._model_info.parameters.defaults.get(par.name, np.nan)
                return default, [default], [1.0]
        elif par.polydisperse:
            value = self.params[par.name]
            dis = self.dispersion[par.name]
            if dis['type'] == 'array':
                dispersity, weight = dis['values'], dis['weights']
            else:
                dispersity, weight = weights.get_weights(
                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                    value, par.limits, par.relative_pd)
            return value, dispersity, weight
        else:
            value = self.params[par.name]
            return value, [value], [1.0]
786
def test_cylinder():
    # type: () -> float
    """
    Evaluate the default cylinder model at (qx, qy) = (0.1, 0.1) and
    return the result.
    """
    cylinder = _make_standard_model('cylinder')()
    return cylinder.evalDistribution([0.1, 0.1])
795
def test_structure_factor():
    # type: () -> float
    """
    Check that hardsphere is not NaN in 2-D at (qx, qy) = (0.1, 0.1) or in
    1-D at the same |q|.
    """
    model = _make_standard_model('hardsphere')()
    q = 0.1
    value2d = model.evalDistribution([q, q])
    value1d = model.evalDistribution(np.array([q*np.sqrt(2)]))
    if np.isnan(value1d) or np.isnan(value2d):
        raise ValueError("hardsphere returns nan")
808
def test_product():
    # type: () -> None
    """
    Test that the 2-D cylinder*hayter_msa product model runs and doesn't
    produce NaN.
    """
    S = _make_standard_model('hayter_msa')()
    P = _make_standard_model('cylinder')()
    model = MultiplicationModel(P, S)
    value = model.evalDistribution([0.1, 0.1])
    if np.isnan(value):
        raise ValueError("cylinder*hayter_msa returns NaN")
820
def test_rpa():
    # type: () -> float
    """
    Evaluate the RPA model with multiplicity 3 at (qx, qy) = (0.1, 0.1)
    and return the result.
    """
    rpa = _make_standard_model('rpa')(3)
    return rpa.evalDistribution([0.1, 0.1])
829
def test_empty_distribution():
    # type: () -> None
    """
    Check that a cylinder with an invalid (negative) radius, which leaves
    no polydispersity points, evaluates to exactly 0 when background is 0.
    """
    cylinder = _make_standard_model('cylinder')()
    for par_name, par_value in (('radius', -1.0), ('background', 0.)):
        cylinder.setParam(par_name, par_value)
    Iq = cylinder.evalDistribution(np.asarray([0.1]))
    assert Iq[0] == 0., "empty distribution fails"
841
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models

    Any failure is re-raised after annotating it with the model name.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            annotate_exception("when loading "+name)
            raise
854
def test_old_name():
    # type: () -> None
    """
    Load and run cylinder model as sas-models-CylinderModel

    Skipped when old-style plugin support is off or sasview is not importable.
    """
    if SUPPORT_OLD_STYLE_PLUGINS:
        try:
            import sas  # pylint: disable=unused-import
        except ImportError:
            # sasview is not on the path, so there is nothing to test
            sas = None
        if sas is not None:
            load_standard_models()
            from sas.models.CylinderModel import CylinderModel
            CylinderModel().evalDistribution([0.1, 0.1])
870
def test_structure_factor_background():
    # type: () -> None
    """
    Compare the sasview wrapper with DirectModel for hardsphere at q=0.

    The sasview value must equal the direct value both with an explicit
    background=0 and with DirectModel's default background.
    """
    from .data import empty_data1D
    from .core import load_model_info, build_model
    from .direct_model import DirectModel

    model_name = "hardsphere"
    q = [0.0]

    sasview_model = _make_standard_model(model_name)()
    sasview_value = sasview_model.evalDistribution(np.array(q))[0]

    kernel_model = build_model(load_model_info(model_name))
    direct_model = DirectModel(empty_data1D(q), kernel_model)
    assert sasview_value == direct_model(background=0.0)

    # Additionally check that direct value background defaults to zero
    assert sasview_value == direct_model()
897
898
def magnetic_demo():
    """
    Plot log(I + 0.001) for a sphere with M0:sld=8 on a 500x500 grid of
    qx, qy in [-0.35, 0.35].  Requires pylab.
    """
    model = _make_standard_model('sphere')()
    model.setParam('M0:sld', 8)
    q = np.linspace(-0.35, 0.35, 500)
    qx, qy = np.meshgrid(q, q)
    iq = model.calculate_Iq(qx.flatten(), qy.flatten()).reshape(qx.shape)

    import pylab
    pylab.imshow(np.log(iq + 0.001))
    pylab.show()
911
if __name__ == "__main__":
    # Smoke test when run as a script; uncomment the others as needed.
    cylinder_value = test_cylinder()
    print("cylinder(0.1,0.1)=%g" % cylinder_value)
    #magnetic_demo()
    #test_product()
    #test_structure_factor()
    #print("rpa:", test_rpa())
    #test_empty_distribution()
    #test_structure_factor_background()
Note: See TracBrowser for help on using the repository browser.