source: sasmodels/sasmodels/sasview_model.py @ b171acd

Branches: core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since b171acd was b171acd, checked in by Paul Kienzle <pkienzle@…>, 5 years ago

move 'multiplicity' handling into sasview model. Refs #1022.

  • Property mode set to 100644
File size: 33.3 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17from os.path import basename, splitext, abspath, getmtime
18try:
19    import _thread as thread
20except ImportError:
21    import thread
22
23import numpy as np  # type: ignore
24
25from . import core
26from . import custom
27from . import kernelcl
28from . import product
29from . import generate
30from . import weights
31from . import modelinfo
32from .details import make_kernel_args, dispersion_mesh
33from .kernelcl import reset_environment
34
35# pylint: disable=unused-import
36try:
37    from typing import (Dict, Mapping, Any, Sequence, Tuple, NamedTuple,
38                        List, Optional, Union, Callable)
39    from .modelinfo import ModelInfo, Parameter
40    from .kernel import KernelModel
41    MultiplicityInfoType = NamedTuple(
42        'MultiplicityInfo',
43        [("number", int), ("control", str), ("choices", List[str]),
44         ("x_axis_label", str)])
45    SasviewModelType = Callable[[int], "SasviewModel"]
46except ImportError:
47    pass
48# pylint: enable=unused-import
49
logger = logging.getLogger(__name__)

# Serializes kernel evaluations; held by SasviewModel.calculate_Iq and
# SasviewModel.calc_composition_models.
calculation_lock = thread.allocate_lock()

#: True if pre-existing plugins, with the old names and parameters, should
#: continue to be supported.
SUPPORT_OLD_STYLE_PLUGINS = True

# TODO: separate x_axis_label from multiplicity info
#: (number of variants, control parameter name, choice labels, profile x label)
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo', "number control choices x_axis_label")

#: set of defined models (standard and custom), keyed by model name
MODELS = dict()  # type: Dict[str, SasviewModelType]
# TODO: remove unused MODEL_BY_PATH cache once sasview no longer references it
#: custom model {path: model} mapping so we can check timestamps
MODEL_BY_PATH = dict()  # type: Dict[str, SasviewModelType]
#: Kernel module last loaded for each custom model path, used to detect
#: whether the model has changed since it was last loaded.
_CACHED_MODULE = dict()  # type: Dict[str, "module"]
72
def reset_environment():
    # type: () -> None
    """
    Reset the compute engine so that the GUI can switch devices.

    Calls :func:`kernelcl.reset_environment`, then forgets the compiled
    kernel cached on every registered model class.  This includes kernels
    in use on fit pages; each one is rebuilt the next time its model is
    evaluated.
    """
    kernelcl.reset_environment()
    for model_class in MODELS.values():
        # The compiled kernel is stored on the class, not on instances.
        model_class._model = None
84
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Look up a model class by name.

    A name ending in ``.py`` is treated as a path and loaded with
    :func:`load_custom_model`.  Any other name must already be registered
    in :data:`MODELS`.

    :raises ValueError: if *modelname* is not a ``.py`` path and is not
        registered.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
99
100
# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Register every built-in model and return all registered model classes.

    Each name from :func:`core.list_models` is built with
    :func:`_make_standard_model` and stored in :data:`MODELS` under that
    name.  A model that fails to build is skipped after its traceback is
    logged at error level.  When SUPPORT_OLD_STYLE_PLUGINS is true, the
    legacy ``sas.models`` aliases are registered as well.

    The returned list is built from all of :data:`MODELS`, so it also
    contains any custom models registered before this call.
    """
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
        except Exception:
            # Best effort: one broken model must not prevent the others
            # from loading.
            logger.error(traceback.format_exc())
        else:
            MODELS[model_name] = model_class

    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()

    return list(MODELS.values())
119
120
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load a custom model given the model path.

    The kernel module is loaded with :func:`custom.load_custom_kernel_module`
    and turned into a model class:

    * If the module defines ``Model`` (old-style plugin), that class is used
      directly.  Its *name*, *filename* and *id* are filled in when missing.
    * Otherwise a new class is built from the module's model info.

    If a model with the same name but a different file is already registered
    in :data:`MODELS`, the new model is renamed to its *id*.  If that name is
    also taken by a different file, it becomes *id* + ``'_user'``.

    The registry entry is replaced only when the module object differs from
    the one seen for *path* on the previous call, or when the name is not
    registered yet.

    :return: the model class registered under the resulting name.
    """
    #logger.info("Loading model %s", path)

    # Load the kernel module.  This may already be cached by the loader, so
    # only requires checking the timestamps of the dependents.
    kernel_module = custom.load_custom_kernel_module(path)

    # Check if the module has changed since we last looked.
    reloaded = kernel_module != _CACHED_MODULE.get(path, None)
    _CACHED_MODULE[path] = kernel_module

    # Turn the module into a model.  We need to do this even if the
    # model has already been loaded so that we can determine the model
    # name and retrieve it from the MODELS cache.
    model = getattr(kernel_module, 'Model', None)
    if model is not None:
        # Old style models do not set the name in the class attributes, so
        # set it here; this name will be overridden when the object is created
        # with an instance variable that has the same value.
        if model.name == "":
            model.name = splitext(basename(path))[0]
        if not hasattr(model, 'filename'):
            filename = abspath(kernel_module.__file__)
            # Point at the source file rather than the compiled one.  Only a
            # trailing '.pyc' is stripped, so a directory whose name contains
            # '.pyc' is left untouched.
            if filename.endswith('.pyc'):
                filename = filename[:-1]
            model.filename = filename
        if not hasattr(model, 'id'):
            model.id = splitext(basename(model.filename))[0]
    else:
        model_info = modelinfo.make_model_info(kernel_module)
        model = make_model_from_info(model_info)

    # If a model name already exists and we are loading a different model,
    # use the model file name as the model name.
    if model.name in MODELS and model.filename != MODELS[model.name].filename:
        previous_name = model.name
        model.name = model.id

        # If the new model name is still in the model list (for instance,
        # if we put a cylinder.py in our plug-in directory), then append
        # an identifier.
        if model.name in MODELS and model.filename != MODELS[model.name].filename:
            model.name = model.id + '_user'
        logger.info("Model %s already exists: using %s [%s]",
                    previous_name, model.name, model.filename)

    # Only update the model if the module has changed
    if reloaded or model.name not in MODELS:
        MODELS[model.name] = model

    return MODELS[model.name]
173
174
def make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Build a :class:`SasviewModel` subclass wrapping *model_info*.

    The class is named ``model_info.name``.  It carries the attributes
    produced by :func:`_generate_model_attributes`, plus ``filename`` and an
    ``__init__`` that accepts an optional *multiplicity*.
    """
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs = _generate_model_attributes(model_info)
    class_attrs.update(__init__=__init__, filename=model_info.filename)
    constructed = type(model_info.name, (SasviewModel,), class_attrs) # type: SasviewModelType
    return constructed
187
188
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Load the sasview model defined by *name*.

    The kernel module is loaded with :func:`generate.load_kernel_module`,
    converted to model info, and wrapped with :func:`make_model_from_info`.

    Returns a class that can be used directly as a sasview model.
    """
    # NOTE(review): an earlier docstring said *name* may also be a path to a
    # custom model; that depends on generate.load_kernel_module -- confirm.
    module = generate.load_kernel_module(name)
    info = modelinfo.make_model_info(module)
    return make_model_from_info(info)
201
202
def _register_old_models():
    # type: () -> None
    """
    Place the new models into sasview under the old names.

    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
    is available to the plugin modules.  Then, for each entry of the
    (3, 1, 2) conversion table, register a ``sas.models.<OldName>`` entry in
    ``sys.modules``.  That entry is a class exposing the new model under its
    old name.
    """
    import sys
    import sas   # needed in order to set sas.models
    import sas.sascalc.fit
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit
    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE

    legacy_table = CONVERSION_TABLE.get((3, 1, 2), {})
    for versioned_name, conversion in legacy_table.items():
        # CoreShellEllipsoidModel => core_shell_ellipsoid:1
        base_name = versioned_name.split(':')[0]
        legacy_name = conversion[2] if len(conversion) >= 3 else conversion[0]
        legacy_path = 'sas.models.' + legacy_name
        shim = type(legacy_name, (), {legacy_name: find_model(base_name)})
        # NOTE(review): the attribute name set here is the full dotted path,
        # so it cannot be reached as sas.models.<OldName> through normal
        # attribute access; only the sys.modules entry makes it importable.
        # Confirm whether legacy_name was intended.
        setattr(sas.models, legacy_path, shim)
        sys.modules[legacy_path] = shim
218
219    for new_name, conversion in CONVERSION_TABLE.get((3, 1, 2), {}).items():
220        # CoreShellEllipsoidModel => core_shell_ellipsoid:1
221        new_name = new_name.split(':')[0]
222        old_name = conversion[0] if len(conversion) < 3 else conversion[2]
223        module_attrs = {old_name: find_model(new_name)}
224        ConstructedModule = type(old_name, (), module_attrs)
225        old_path = 'sas.models.' + old_name
226        setattr(sas.models, old_path, ConstructedModule)
227        sys.modules[old_path] = ConstructedModule
228
229
def MultiplicationModel(form_factor, structure_factor):
    # type: ("SasviewModel", "SasviewModel") -> "SasviewModel"
    """
    Return an instance of the product model of two models.

    *form_factor* and *structure_factor* are model instances.  Their model
    info is combined with :func:`product.make_product_info`, and a class is
    built from the result.  That class is instantiated with the form
    factor's multiplicity.
    """
    product_info = product.make_product_info(
        form_factor._model_info, structure_factor._model_info)
    ProductModel = make_model_from_info(product_info)
    return ProductModel(form_factor.multiplicity)
239
240
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Generate the class attributes for the model.

    This should include all the information necessary to query the model
    details so that you do not need to instantiate a model to query it.

    All the attributes should be immutable to avoid accidents.

    Multiplicity is driven by the first kernel parameter flagged as a control
    parameter.  Its variant count is the number of choices if it has any,
    otherwise its integer upper limit.
    """
    # TODO: allow model to override axis labels input/output name/unit
    kernel_pars = model_info.parameters.kernel_parameters

    # Process multiplicity
    control = next((p.id for p in kernel_pars if p.is_control), None)
    non_fittable = []  # type: List[str]
    xlabel = model_info.profile_axes[0] if model_info.profile is not None else ""
    variants = MultiplicityInfo(0, "", [], xlabel)
    control_par = next((p for p in kernel_pars if p.id == control), None)
    if control_par is not None:
        non_fittable.append(control_par.name)
        if control_par.choices:
            count = len(control_par.choices)
        else:
            count = int(control_par.limits[1])
        variants = MultiplicityInfo(count, control_par.name,
                                    control_par.choices, xlabel)

    # Only a single drop-down list parameter available
    fun_list = []
    choice_par = next((p for p in kernel_pars if p.choices), None)
    if choice_par is not None:
        fun_list = choice_par.choices
        if choice_par.length > 1:
            non_fittable.extend(choice_par.id + str(k)
                                for k in range(1, choice_par.length + 1))

    # Organize parameter sets
    orientation_params, magnetic_params, fixed = [], [], []
    for p in model_info.parameters.user_parameters({}, is2d=True):
        width_name = p.name + ".width"
        if p.type == 'orientation':
            orientation_params.extend([p.name, width_name])
            fixed.append(width_name)
        elif p.type == 'magnetic':
            orientation_params.append(p.name)
            magnetic_params.append(p.name)
            fixed.append(width_name)

    # Build class dictionary
    attrs = {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }  # type: Dict[str, Any]
    return attrs
312
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete models are subclasses of this class whose class attributes are
    generated from a model info object.  Each instance holds its current
    parameter values (:attr:`params`) and polydispersity settings
    (:attr:`dispersion`).  The compiled kernel (:attr:`_model`) is stored on
    the class, so all instances of a model share it.
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: List[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: List[str]
    #: names of the fittable parameters
    fixed = None              # type: List[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]
    #: the calculator's ``results`` attribute (None if it has none) from the
    #: most recent kernel evaluation; read by calc_composition_models.
    _intermediate_results = None

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Create a model instance with default parameter values.

        :param multiplicity: number of variants to expose.  When given, the
            parameters hidden for that multiplicity, and the control
            parameter itself, are left out of :attr:`params`.
        """
        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            hidden = self._model_info.get_hidden_parameters(multiplicity)
            hidden |= set([self.multiplicity_info.control])
        else:
            hidden = set()
        if self._model_info.structure_factor:
            hidden.add('scale')
            hidden.add('background')
            # Note: this writes into the model info shared by the class, so
            # every instance built from it sees background=0 as the default.
            self._model_info.parameters.defaults['background'] = 0.

        # Update the parameter lists to exclude any hidden parameters
        self.magnetic_params = tuple(pname for pname in self.magnetic_params
                                     if pname not in hidden)
        self.orientation_params = tuple(pname for pname in self.orientation_params
                                        if pname not in hidden)

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters({}, is2d=True):
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __get_state__(self):
        # type: () -> Dict[str, Any]
        """
        Return a copy of the instance dictionary without the compiled kernel.

        NOTE: the pickle protocol hooks are spelled ``__getstate__`` and
        ``__setstate__``, so pickle and deepcopy do not call these methods.
        Renaming them needs care: ``__set_state__`` assigns ``self._model``
        on the instance, which would hide the kernel that
        :meth:`_calculate_Iq` stores on the class.
        """
        state = self.__dict__.copy()
        # The kernel normally lives on the class, so it is usually absent
        # from the instance dictionary; don't fail when it is missing.
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __set_state__(self, state):
        # type: (Dict[str, Any]) -> None
        """Restore the instance dictionary from *state* and drop the kernel."""
        self.__dict__ = state
        self._model = None

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check
        :return: True if *par_name* is listed in :attr:`fixed`
        """
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()

    def getProfile(self):
        # type: () -> (np.ndarray, np.ndarray)
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values

        The SLD values returned by the model's profile function are scaled
        by 1e-6.  Parameters missing from :attr:`params` are passed as NaN.
        The control parameter takes :attr:`multiplicity` when it is set, and
        otherwise the value stored in :attr:`params`.
        """
        control = self.multiplicity_info.control
        args = {} # type: Dict[str, Any]
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == control and self.multiplicity is not None:
                value = float(self.multiplicity)
            elif p.length == 1:
                value = self.params.get(p.id, np.nan)
            else:
                # Vector parameters are stored as id1, id2, ..., idN.
                value = np.array([self.params.get(p.id+str(k), np.nan)
                                  for k in range(1, p.length+1)])
            args[p.id] = value

        x, y = self._model_info.profile(**args)
        return x, 1e-6*y

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        :param name: name of the parameter; ``"par.field"`` addresses a
            dispersion field such as ``"radius.width"``
        :param value: value of the parameter
        :raises ValueError: if the parameter does not exist
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            par, field = toks
            if par in self.dispersion and field in self.dispersion[par]:
                self.dispersion[par][field] = value
                return
        elif name in self.params:
            # Look for standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        :param name: name of the parameter; ``"par.field"`` addresses a
            dispersion field such as ``"radius.width"``
        :raises ValueError: if the parameter does not exist
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            par, field = toks
            if par in self.dispersion and field in self.dispersion[par]:
                return self.dispersion[par][field]
        elif name in self.params:
            # Look for standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model

        Each dispersion parameter contributes ``.npts``, ``.nsigmas`` and
        ``.width`` entries, in that order.
        """
        # TODO: fix test so that parameter order doesn't matter
        return ['%s.%s' % (p_name, ext)
                for p_name in self.dispersion.keys()
                for ext in ('npts', 'nsigmas', 'width')]

    def clone(self):
        # type: () -> "SasviewModel"
        """ Return a identical copy of self """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]


    def runXY(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calc_composition_models(self, qx):
        """
        returns parts of the composition model or None if not a composition
        model.
        """
        # TODO: have calculate_Iq return the intermediates.
        #
        # The current interface causes calculate_Iq() to be called twice,
        # once to get the combined result and again to get the intermediate
        # results.  This is necessary for now.
        # Long term, the solution is to change the interface to calculate_Iq
        # so that it returns a results object containing all the bits:
        #     the A, B, C, ... of the composition model (and any subcomponents?)
        #     the P and S of the product model,
        #     the combined model before resolution smearing,
        #     the sasmodel before sesans conversion,
        #     the oriented 2D model used to fit oriented usans data,
        #     the final I(q),
        #     ...
        #
        # Have the model calculator add all of these blindly to the data
        # tree, and update the graphs which contain them.  The fitter
        # needs to be updated to use the I(q) value only, ignoring the rest.
        #
        # The simple fix of returning the existing intermediate results
        # will not work for a couple of reasons: (1) another thread may
        # sneak in to compute its own results before calc_composition_models
        # is called, and (2) calculate_Iq is currently called three times:
        # once with q, once with q values before qmin and once with q values
        # after q max.  Both of these should be addressed before
        # replacing this code.
        composition = self._model_info.composition
        if composition and composition[0] == 'product': # only P*S for now
            with calculation_lock:
                self._calculate_Iq(qx)
                return self._intermediate_results
        else:
            return None

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        ## uncomment the following when trying to debug the uncoordinated calls
        ## to calculate_Iq
        #if calculation_lock.locked():
        #    logger.info("calculation waiting for another thread to complete")
        #    logger.info("\n".join(traceback.format_stack()))

        with calculation_lock:
            return self._calculate_Iq(qx, qy)

    def _calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Evaluate the kernel at *qx* (and *qy* for 2D) without locking.

        Both callers in this class hold :data:`calculation_lock`.  The
        calculator is released even if the evaluation raises.
        """
        if self._model is None:
            # Only need one copy of the compiled kernel regardless of how many
            # times it is used, so store it in the class.  Also, to reset the
            # compute engine, need to clear out all existing compiled kernels,
            # which is much easier to do if we store them in the class.
            self.__class__._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        calculator = self._model.make_kernel(q_vectors)
        try:
            parameters = self._model_info.parameters
            pairs = [self._get_weights(p) for p in parameters.call_parameters]
            call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
            result = calculator(call_details, values, cutoff=self.cutoff,
                                magnetic=is_magnetic)
            self._intermediate_results = getattr(calculator, 'results', None)
        finally:
            # Free the kernel's resources even when the evaluation fails.
            calculator.release()
        return result

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius, as the weighted average
            over the volume-parameter dispersion mesh (1.0 if the model
            defines no ER function)
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio (1.0 if the model defines no
            VR function)
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion
        :raises ValueError: if *parameter* is not in :attr:`params`
        """
        if parameter in self.params:
            # TODO: Store the disperser object directly in the model.
            # The current method of relying on the sasview GUI to
            # remember them is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.
            # (2) Could write a disperser parameter retriever in sasview.
            # (3) Could modify sasview to use sasmodels.weights dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

    def _dispersion_mesh(self):
        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
        """
        Return dispersion weights for parameter

        :return: (value, dispersity points, weights).  Parameters not in
            :attr:`params` get a single point: the multiplicity for the
            control parameter, otherwise the model default (NaN if none).
        """
        if par.name not in self.params:
            if par.id == self.multiplicity_info.control:
                return self.multiplicity, [self.multiplicity], [1.0]
            else:
                # For hidden parameters use default values.  This sets
                # scale=1 and background=0 for structure factors
                default = self._model_info.parameters.defaults.get(par.name, np.nan)
                return default, [default], [1.0]
        elif par.polydisperse:
            value = self.params[par.name]
            dis = self.dispersion[par.name]
            if dis['type'] == 'array':
                dispersity, weight = dis['values'], dis['weights']
            else:
                dispersity, weight = weights.get_weights(
                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                    value, par.limits, par.relative_pd)
            return value, dispersity, weight
        else:
            value = self.params[par.name]
            return value, [value], [1.0]

    @classmethod
    def runTests(cls):
        """
        Run any tests built into the model and captures the test output.

        Returns success flag and output
        """
        from .model_test import check_model
        return check_model(cls._model_info)
835
def test_cylinder():
    # type: () -> float
    """
    Check that the cylinder model runs at the 2-D point (qx, qy) = (0.1, 0.1).

    The result of ``evalDistribution`` at that point is returned.
    """
    cylinder = _make_standard_model('cylinder')()
    return cylinder.evalDistribution([0.1, 0.1])
844
def test_structure_factor():
    # type: () -> None
    """
    Test that the hardsphere model runs in 1-D and 2-D without producing NaN.

    The 1-D point q = 0.1*sqrt(2) has the same magnitude as the 2-D point
    (qx, qy) = (0.1, 0.1).

    :raises ValueError: if either evaluation contains NaN.
    """
    Model = _make_standard_model('hardsphere')
    model = Model()
    value2d = model.evalDistribution([0.1, 0.1])
    value1d = model.evalDistribution(np.array([0.1*np.sqrt(2)]))
    # Reduce with np.any so the check works for array results of any size;
    # "a or b" on a multi-element boolean array is ambiguous and raises.
    if np.any(np.isnan(value1d)) or np.any(np.isnan(value2d)):
        raise ValueError("hardsphere returns nan")
857
def test_product():
    # type: () -> None
    """
    Test that the 2-D cylinder*hayter_msa product model runs and doesn't
    produce NaN.

    :raises ValueError: if the evaluation at (qx, qy) = (0.1, 0.1)
        contains NaN.
    """
    S = _make_standard_model('hayter_msa')()
    P = _make_standard_model('cylinder')()
    model = MultiplicationModel(P, S)
    value = model.evalDistribution([0.1, 0.1])
    # np.any keeps the check well-defined for array results of any size.
    if np.any(np.isnan(value)):
        raise ValueError("cylinder*hayter_msa returns NaN")
869
def test_rpa():
    # type: () -> float
    """
    Check that the 2-D RPA model runs.

    The model is constructed with the argument 3 and evaluated at
    (qx, qy) = (0.1, 0.1); that result is returned.
    """
    rpa = _make_standard_model('rpa')(3)
    return rpa.evalDistribution([0.1, 0.1])
878
def test_empty_distribution():
    # type: () -> None
    """
    Make sure that sasmodels returns zero when there are no polydispersity
    points.

    The cylinder radius is set to -1, which is meant to leave no usable
    polydispersity points, and the background is set to 0.  The intensity
    at q = 0.1 must therefore be exactly 0.
    """
    Cylinder = _make_standard_model('cylinder')
    cylinder = Cylinder()
    cylinder.setParam('radius', -1.0)
    cylinder.setParam('background', 0.)
    Iq = cylinder.evalDistribution(np.asarray([0.1]))
    assert Iq[0] == 0., "empty distribution fails"
890
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models.

    If a model fails to build, the error is annotated with the offending
    model's name and then re-raised.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            # Catch Exception rather than using a bare except, so that
            # KeyboardInterrupt/SystemExit pass through unannotated.
            annotate_exception("when loading "+name)
            raise
903
def test_old_name():
    # type: () -> None
    """
    Load and run cylinder model as sas-models-CylinderModel

    Returns early, without testing, when old-style plugin support is
    disabled or when the ``sas`` package cannot be imported.
    """
    if not SUPPORT_OLD_STYLE_PLUGINS:
        return
    # if sasview is not on the path then don't try to test it
    try:
        import sas  # pylint: disable=unused-import
    except ImportError:
        return
    load_standard_models()
    from sas.models.CylinderModel import CylinderModel
    old_style_model = CylinderModel()
    old_style_model.evalDistribution([0.1, 0.1])
919
def magnetic_demo():
    """
    Display a 2-D image of the log intensity of the sphere model with
    sld_M0 = 8.

    The model is evaluated on a 500x500 (qx, qy) grid spanning -0.35 to
    0.35 on both axes.  log(I + 0.001) is then shown with pylab; the small
    offset keeps the logarithm finite where I is zero.
    """
    model = _make_standard_model('sphere')()
    model.setParam('sld_M0', 8)
    axis = np.linspace(-0.35, 0.35, 500)
    qx, qy = np.meshgrid(axis, axis)
    flat_iq = model.calculate_Iq(qx.flatten(), qy.flatten())
    image = flat_iq.reshape(qx.shape)

    import pylab
    pylab.imshow(np.log(image + 0.001))
    pylab.show()
932
if __name__ == "__main__":
    # Smoke check: evaluate the cylinder model at (qx, qy) = (0.1, 0.1).
    # Uncomment any of the calls below to run the other checks/demos too.
    cylinder_iq = test_cylinder()
    print("cylinder(0.1,0.1)=%g" % cylinder_iq)
    #magnetic_demo()
    #test_product()
    #test_structure_factor()
    #print("rpa:", test_rpa())
    #test_empty_distribution()
Note: See TracBrowser for help on using the repository browser.