source: sasmodels/sasmodels/sasview_model.py @ a4f1a73

core_shell_microgelsmagnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since a4f1a73 was a4f1a73, checked in by Paul Kienzle <pkienzle@…>, 5 years ago

clean out existing kernels when changing compute engine

  • Property mode set to 100644
File size: 33.2 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17from os.path import basename, splitext, abspath, getmtime
18try:
19    import _thread as thread
20except ImportError:
21    import thread
22
23import numpy as np  # type: ignore
24
25from . import core
26from . import custom
27from . import kernelcl
28from . import product
29from . import generate
30from . import weights
31from . import modelinfo
32from .details import make_kernel_args, dispersion_mesh
33from .kernelcl import reset_environment
34
35# pylint: disable=unused-import
36try:
37    from typing import (Dict, Mapping, Any, Sequence, Tuple, NamedTuple,
38                        List, Optional, Union, Callable)
39    from .modelinfo import ModelInfo, Parameter
40    from .kernel import KernelModel
41    MultiplicityInfoType = NamedTuple(
42        'MultiplicityInfo',
43        [("number", int), ("control", str), ("choices", List[str]),
44         ("x_axis_label", str)])
45    SasviewModelType = Callable[[int], "SasviewModel"]
46except ImportError:
47    pass
48# pylint: enable=unused-import
49
#: module-level logger
logger = logging.getLogger(__name__)

#: lock held while a kernel is being evaluated (see calculate_Iq and
#: calc_composition_models)
calculation_lock = thread.allocate_lock()

#: True if pre-existing plugins, with the old names and parameters, should
#: continue to be supported.
SUPPORT_OLD_STYLE_PLUGINS = True

# TODO: separate x_axis_label from multiplicity info
#: (number, control, choices, x_axis_label) describing the model variants
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo', "number control choices x_axis_label")

#: set of defined models (standard and custom)
MODELS = dict()  # type: Dict[str, SasviewModelType]
# TODO: remove unused MODEL_BY_PATH cache once sasview no longer references it
#: custom model {path: model} mapping so we can check timestamps
MODEL_BY_PATH = dict()  # type: Dict[str, SasviewModelType]
#: Track modules that we have loaded so we can determine whether the model
#: has changed since we last reloaded.
_CACHED_MODULE = dict()  # type: Dict[str, "module"]
72
def reset_environment():
    # type: () -> None
    """
    Clear the compute engine context so that the GUI can change devices.

    This removes all compiled kernels, even those that are active on fit
    pages, but they will be restored the next time they are needed.

    Compiled kernels are cached as the class attribute ``_model`` of each
    model class (see SasviewModel._calculate_Iq).  Registered models are
    found through MODELS.  Classes built on the fly and never registered,
    such as the product classes created by MultiplicationModel, are found
    by walking the SasviewModel subclass tree.  Without that walk they would
    keep a kernel compiled for the previous compute engine.
    """
    kernelcl.reset_environment()

    # Registered models: standard, custom, and old-style plugin classes.
    for model in MODELS.values():
        model._model = None

    # Unregistered SasviewModel subclasses (e.g., P(Q)*S(Q) product classes).
    pending = list(SasviewModel.__subclasses__())
    while pending:
        cls = pending.pop()
        cls._model = None
        pending.extend(cls.__subclasses__())
84
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Look up a model class by name.

    A name ending in '.py' is treated as the path to a custom model file and
    is loaded with load_custom_model.  Any other name must already be a key
    of MODELS.

    :raises ValueError: if the name is not a registered model.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r" % modelname)
    return MODELS[modelname]
99
100
101# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Build the wrappers for the predefined models and return them.

    Every name from core.list_models() is converted and stored in MODELS.
    A model that fails to load is skipped, and its traceback is logged at
    error level.  When SUPPORT_OLD_STYLE_PLUGINS is set, the old sasview
    names are registered as well.

    The returned list contains every value currently in MODELS, which
    includes any custom models loaded earlier.
    """
    for model_name in core.list_models():
        try:
            wrapper = _make_standard_model(model_name)
        except Exception:
            logger.error(traceback.format_exc())
        else:
            MODELS[model_name] = wrapper

    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()

    return [model for model in MODELS.values()]
119
120
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load a custom model given the model path.

    The resulting model class is registered in MODELS and returned.

    If a model with the same name but a different file is already
    registered, the new model is renamed to ``model.id``.  If that name is
    also taken by a different file, ``model.id + '_user'`` is used instead.

    An existing MODELS entry is replaced only when the kernel module differs
    from the one cached for *path* on the previous call.
    """
    # Load the kernel module.  This may already be cached by the loader, so
    # only requires checking the timestamps of the dependents.
    kernel_module = custom.load_custom_kernel_module(path)

    # Check if the module has changed since we last looked.
    reloaded = kernel_module != _CACHED_MODULE.get(path, None)
    _CACHED_MODULE[path] = kernel_module

    # Turn the module into a model.  We need to do this in even if the
    # model has already been loaded so that we can determine the model
    # name and retrieve it from the MODELS cache.
    model = getattr(kernel_module, 'Model', None)
    if model is not None:
        # Old style models do not set the name in the class attributes, so
        # set it here; this name will be overridden when the object is created
        # with an instance variable that has the same value.
        if model.name == "":
            model.name = splitext(basename(path))[0]
        if not hasattr(model, 'filename'):
            filename = abspath(kernel_module.__file__)
            # Point at the source rather than the compiled bytecode.  Only
            # the suffix is changed; str.replace would also rewrite any
            # '.pyc' that happens to appear in a directory name.
            if filename.endswith('.pyc'):
                filename = filename[:-1]
            model.filename = filename
        if not hasattr(model, 'id'):
            model.id = splitext(basename(model.filename))[0]
    else:
        model_info = modelinfo.make_model_info(kernel_module)
        model = make_model_from_info(model_info)

    # If a model name already exists and we are loading a different model,
    # use the model file name as the model name.
    if model.name in MODELS and model.filename != MODELS[model.name].filename:
        _previous_name = model.name
        model.name = model.id

        # If the new model name is still in the model list (for instance,
        # if we put a cylinder.py in our plug-in directory), then append
        # an identifier.
        if model.name in MODELS and model.filename != MODELS[model.name].filename:
            model.name = model.id + '_user'
        logger.info("Model %s already exists: using %s [%s]",
                    _previous_name, model.name, model.filename)

    # Only update the model if the module has changed
    if reloaded or model.name not in MODELS:
        MODELS[model.name] = model

    return MODELS[model.name]
173
174
def make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Create a SasviewModel subclass wrapping *model_info*.

    The class is named after ``model_info.name``.  Its attributes come from
    _generate_model_attributes, plus ``filename`` and an ``__init__`` that
    forwards *multiplicity* to SasviewModel.__init__.
    """
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs = _generate_model_attributes(model_info)
    class_attrs.update(__init__=__init__, filename=model_info.filename)
    return type(model_info.name, (SasviewModel,), class_attrs)  # type: SasviewModelType
187
188
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build the sasview model class for *name*.

    *name* is handed to generate.load_kernel_module; the original
    documentation states it may be a standard model name or a path to a
    custom model (TODO: confirm the loader accepts paths).

    Returns a class that can be used directly as a sasview model.
    """
    module = generate.load_kernel_module(name)
    return make_model_from_info(modelinfo.make_model_info(module))
201
202
def _register_old_models():
    # type: () -> None
    """
    Place the new models into sasview under the old names.

    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
    is available to the plugin modules.

    For every entry of CONVERSION_TABLE keyed by (3, 1, 2), this:

    * builds a throw-away class named after the old model, whose only
      attribute (also the old name) is the new model class, and
    * registers that class in sys.modules as 'sas.models.<old name>'.
    """
    import sys
    import sas   # needed in order to set sas.models
    import sas.sascalc.fit
    # Alias the fit package as 'sas.models' in the import cache and on the
    # sas package before importing it under that name.
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit
    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE

    table = CONVERSION_TABLE.get((3, 1, 2), {})
    for versioned_name, conversion in table.items():
        # CoreShellEllipsoidModel => core_shell_ellipsoid:1
        new_name = versioned_name.split(':')[0]
        # Use conversion[2] when present, otherwise conversion[0].
        old_name = conversion[2] if len(conversion) >= 3 else conversion[0]
        wrapper = type(old_name, (), {old_name: find_model(new_name)})
        old_path = 'sas.models.' + old_name
        # NOTE(review): the attribute name is the full dotted path, so it is
        # not reachable as sas.models.<old name>; confirm whether old_name
        # was intended.  The sys.modules entry is what makes imports work.
        setattr(sas.models, old_path, wrapper)
        sys.modules[old_path] = wrapper
228
229
def MultiplicationModel(form_factor, structure_factor):
    # type: ("SasviewModel", "SasviewModel") -> "SasviewModel"
    """
    Returns a constructed product model from form_factor and structure_factor.

    A new model class is built from the combined model info of the two
    inputs.  It is instantiated with the multiplicity of *form_factor*.
    """
    product_info = product.make_product_info(
        form_factor._model_info, structure_factor._model_info)
    ProductClass = make_model_from_info(product_info)
    return ProductClass(form_factor.multiplicity)
239
240
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Build the class attribute dictionary for a model wrapper class.

    The attributes describe the model (name, id, category, parameter
    groupings, multiplicity), so the model can be queried without creating
    an instance.  Sequences are stored as tuples so that the class-level
    values cannot be modified by accident.
    """
    # TODO: allow model to override axis labels input/output name/unit

    non_fittable = []  # type: List[str]

    # Multiplicity comes from the model's control parameter, if it has one.
    xlabel = "" if model_info.profile is None else model_info.profile_axes[0]
    control_par = next(
        (p for p in model_info.parameters.kernel_parameters
         if p.name == model_info.control),
        None)
    if control_par is None:
        variants = MultiplicityInfo(0, "", [], xlabel)
    else:
        non_fittable.append(control_par.name)
        count = (len(control_par.choices) if control_par.choices
                 else int(control_par.limits[1]))
        variants = MultiplicityInfo(count, control_par.name,
                                    control_par.choices, xlabel)

    # Only a single drop-down list parameter available
    choice_par = next(
        (p for p in model_info.parameters.kernel_parameters if p.choices),
        None)
    fun_list = choice_par.choices if choice_par is not None else []
    if choice_par is not None and choice_par.length > 1:
        non_fittable.extend(choice_par.id + str(k)
                            for k in range(1, choice_par.length + 1))

    # Group the 2D-only parameters.  Magnetic parameters are also listed in
    # orientation_params; every parameter listed here has its ".width"
    # entry added to fixed.
    orientation_params = []
    magnetic_params = []
    fixed = []
    for p in model_info.parameters.user_parameters({}, is2d=True):
        if p.type not in ('orientation', 'magnetic'):
            continue
        orientation_params.append(p.name)
        if p.type == 'orientation':
            orientation_params.append(p.name + ".width")
        else:
            magnetic_params.append(p.name)
        fixed.append(p.name + ".width")

    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }
309
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete models are subclasses created by make_model_from_info, which
    fills in the class attributes below from the model info.  Each instance
    holds its own parameter values (params), dispersity settings
    (dispersion) and parameter details (details).
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    #: compiled kernel, cached on the class so every instance shares it;
    #: built on demand by _calculate_Iq and cleared by reset_environment()
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: results attached to the calculator by the most recent _calculate_Iq
    #: call (returned by calc_composition_models)
    _intermediate_results = None  # type: Any
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: List[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: List[str]
    #: names of the fittable parameters
    fixed = None              # type: List[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Initialize the per-instance parameter tables.

        :param multiplicity: selected variant of a multiplicity model, or
            None.  When given, the parameters hidden for that multiplicity,
            and the control parameter itself, are left out of params,
            dispersion and details.
        """

        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            hidden = self._model_info.get_hidden_parameters(multiplicity)
            hidden |= set([self.multiplicity_info.control])
        else:
            hidden = set()
        if self._model_info.structure_factor:
            # Structure factors are used inside a product model, so scale
            # and background are hidden; _get_weights supplies the defaults.
            hidden.add('scale')
            hidden.add('background')
            self._model_info.parameters.defaults['background'] = 0.

        # Update the parameter lists to exclude any hidden parameters
        self.magnetic_params = tuple(pname for pname in self.magnetic_params
                                     if pname not in hidden)
        self.orientation_params = tuple(pname for pname in self.orientation_params
                                        if pname not in hidden)

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters({}, is2d=True):
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                # Relative dispersity widths are capped at 1.
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __get_state__(self):
        # type: () -> Dict[str, Any]
        """
        Return a copy of the instance state without the compiled kernel.

        The compiled kernel is cached on the class (see _calculate_Iq), so
        '_model' is normally absent from the instance dictionary.  It is
        removed only if present.

        NOTE(review): pickle and copy call __getstate__/__setstate__; these
        spellings are not invoked by them.  Confirm which callers use them.
        """
        state = self.__dict__.copy()
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __set_state__(self, state):
        # type: (Dict[str, Any]) -> None
        """
        Restore the instance state produced by __get_state__.

        Any instance-level '_model' is dropped so that lookups fall through
        to the class-level kernel cache.  Assigning ``self._model = None``
        would shadow that cache and hide the kernel built by _calculate_Iq.
        """
        self.__dict__ = state
        self.__dict__.pop('_model', None)

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check
        :return: True if *par_name* is listed in the (ill-named) ``fixed``
            attribute
        """
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()


    def getProfile(self):
        # type: () -> (np.ndarray, np.ndarray)
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        args = {} # type: Dict[str, Any]
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == self.multiplicity_info.control:
                value = float(self.multiplicity)
            elif p.length == 1:
                value = self.params.get(p.id, np.nan)
            else:
                # Vector parameters are stored as p1, p2, ..., pN.
                value = np.array([self.params.get(p.id+str(k), np.nan)
                                  for k in range(1, p.length+1)])
            args[p.id] = value

        x, y = self._model_info.profile(**args)
        return x, 1e-6*y

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        :param name: name of the parameter.  A name of the form
            "<par>.<field>" (e.g. "radius.width") sets a dispersity field.
        :param value: value of the parameter
        :raises ValueError: if the model does not contain the parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            par_name, field = toks
            if par_name in self.dispersion and field in self.dispersion[par_name]:
                self.dispersion[par_name][field] = value
                return
        elif name in self.params:
            # Look for standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        :param name: name of the parameter.  A name of the form
            "<par>.<field>" (e.g. "radius.width") reads a dispersity field.
        :raises ValueError: if the model does not contain the parameter
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            par_name, field = toks
            if par_name in self.dispersion and field in self.dispersion[par_name]:
                return self.dispersion[par_name][field]
        elif name in self.params:
            # Look for standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model
        """
        # TODO: fix test so that parameter order doesn't matter
        return ['%s.%s' % (p_name, ext)
                for p_name in self.dispersion.keys()
                for ext in ('npts', 'nsigmas', 'width')]

    def clone(self):
        # type: () -> "SasviewModel"
        """ Return a identical copy of self """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]


    def runXY(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        :raises TypeError: if *qdist* is neither a list/tuple nor an ndarray
        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calc_composition_models(self, qx):
        """
        returns parts of the composition model or None if not a composition
        model.
        """
        # TODO: have calculate_Iq return the intermediates.
        #
        # The current interface causes calculate_Iq() to be called twice,
        # once to get the combined result and again to get the intermediate
        # results.  This is necessary for now.
        # Long term, the solution is to change the interface to calculate_Iq
        # so that it returns a results object containing all the bits:
        #     the A, B, C, ... of the composition model (and any subcomponents?)
        #     the P and S of the product model,
        #     the combined model before resolution smearing,
        #     the sasmodel before sesans conversion,
        #     the oriented 2D model used to fit oriented usans data,
        #     the final I(q),
        #     ...
        #
        # Have the model calculator add all of these blindly to the data
        # tree, and update the graphs which contain them.  The fitter
        # needs to be updated to use the I(q) value only, ignoring the rest.
        #
        # The simple fix of returning the existing intermediate results
        # will not work for a couple of reasons: (1) another thread may
        # sneak in to compute its own results before calc_composition_models
        # is called, and (2) calculate_Iq is currently called three times:
        # once with q, once with q values before qmin and once with q values
        # after q max.  Both of these should be addressed before
        # replacing this code.
        composition = self._model_info.composition
        if composition and composition[0] == 'product': # only P*S for now
            with calculation_lock:
                self._calculate_Iq(qx)
                return self._intermediate_results
        else:
            return None

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        ## uncomment the following when trying to debug the uncoordinated calls
        ## to calculate_Iq
        #if calculation_lock.locked():
        #    logger.info("calculation waiting for another thread to complete")
        #    logger.info("\n".join(traceback.format_stack()))

        with calculation_lock:
            return self._calculate_Iq(qx, qy)

    def _calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Evaluate the kernel at *qx* (and *qy* for 2D) without locking.

        The callers (calculate_Iq and calc_composition_models) hold
        calculation_lock.  As a side effect, the calculator's intermediate
        results are stored in self._intermediate_results.
        """
        cls = self.__class__
        if cls._model is None:
            # Only need one copy of the compiled kernel regardless of how many
            # times it is used, so store it in the class.  Also, to reset the
            # compute engine, need to clear out all existing compiled kernels,
            # which is much easier to do if we store them in the class.
            cls._model = core.build_model(self._model_info)
        # Read the class-level cache directly so a stray instance attribute
        # cannot hide the kernel that was just built.
        model = cls._model
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        calculator = model.make_kernel(q_vectors)
        try:
            parameters = self._model_info.parameters
            pairs = [self._get_weights(p) for p in parameters.call_parameters]
            call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
            result = calculator(call_details, values, cutoff=self.cutoff,
                                magnetic=is_magnetic)
            self._intermediate_results = getattr(calculator, 'results', None)
        finally:
            # Release the calculator even if the evaluation fails.
            calculator.release()
        return result

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius, averaged over the volume
            parameter dispersity weights, or 1.0 if the model defines no ER
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio, or 1.0 if the model defines
            no VR
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion
        :raises ValueError: if *parameter* is not in self.params
        """
        if parameter in self.params:
            # TODO: Store the disperser object directly in the model.
            # The current method of relying on the sasview GUI to
            # remember them is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.
            # (2) Could write a disperser parameter retriever in sasview.
            # (3) Could modify sasview to use sasmodels.weights dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

    def _dispersion_mesh(self):
        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
        """
        Return dispersion weights for parameter

        :return: (value, dispersity points, weights) for *par*
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return self.multiplicity, [self.multiplicity], [1.0]
            else:
                # For hidden parameters use default values.  This sets
                # scale=1 and background=0 for structure factors
                default = self._model_info.parameters.defaults.get(par.name, np.nan)
                return default, [default], [1.0]
        elif par.polydisperse:
            value = self.params[par.name]
            dis = self.dispersion[par.name]
            if dis['type'] == 'array':
                dispersity, weight = dis['values'], dis['weights']
            else:
                dispersity, weight = weights.get_weights(
                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                    value, par.limits, par.relative_pd)
            return value, dispersity, weight
        else:
            value = self.params[par.name]
            return value, [value], [1.0]

    @classmethod
    def runTests(cls):
        """
        Run any tests built into the model and captures the test output.

        Returns success flag and output
        """
        from .model_test import check_model
        return check_model(cls._model_info)
832
def test_cylinder():
    # type: () -> float
    """
    Evaluate the standard cylinder model at [0.1, 0.1].

    The result of ``evalDistribution`` is returned so the ``__main__``
    block can print it.
    """
    model = _make_standard_model('cylinder')()
    return model.evalDistribution([0.1, 0.1])
841
def test_structure_factor():
    # type: () -> None
    """
    Test that the hardsphere model runs in 1-D and 2-D without producing NaN.

    The 1-D point is q = 0.1*sqrt(2), the same magnitude as the 2-D point
    (0.1, 0.1), so both evaluations probe the same |q|.

    :raises ValueError: if either evaluation contains NaN.
    """
    model = _make_standard_model('hardsphere')()
    value2d = model.evalDistribution([0.1, 0.1])
    value1d = model.evalDistribution(np.array([0.1*np.sqrt(2)]))
    # np.any keeps the check valid regardless of the shape of the result.
    if np.any(np.isnan(value1d)) or np.any(np.isnan(value2d)):
        raise ValueError("hardsphere returns nan")
854
def test_product():
    # type: () -> None
    """
    Test that the 2-D cylinder*hayter_msa product model runs and doesn't
    produce NaN.

    :raises ValueError: if the evaluation contains NaN.
    """
    S = _make_standard_model('hayter_msa')()
    P = _make_standard_model('cylinder')()
    model = MultiplicationModel(P, S)
    value = model.evalDistribution([0.1, 0.1])
    # np.any keeps the check valid regardless of the shape of the result.
    if np.any(np.isnan(value)):
        raise ValueError("cylinder*hayter_msa returns nan")
866
def test_rpa():
    # type: () -> float
    """
    Test that the 2-D RPA model runs.

    The standard RPA model is instantiated with argument 3 and evaluated
    at [0.1, 0.1]; the result is returned.
    """
    return _make_standard_model('rpa')(3).evalDistribution([0.1, 0.1])
875
def test_empty_distribution():
    # type: () -> None
    """
    Make sure that sasmodels returns 0, not NaN, when there are no
    polydispersity points.

    The cylinder radius is set to -1, which is expected to leave no valid
    points in the radius distribution (NOTE(review): presumably rejected by
    the radius limits -- confirm).  The background is set to 0, so I(q)
    must be exactly zero; a NaN result also fails the comparison.
    """
    cylinder = _make_standard_model('cylinder')()
    cylinder.setParam('radius', -1.0)
    cylinder.setParam('background', 0.)
    Iq = cylinder.evalDistribution(np.asarray([0.1]))
    assert Iq[0] == 0., "empty distribution fails"
887
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models.

    Every name from ``core.list_models()`` is built.  The first failure is
    annotated with "when loading <name>" via ``annotate_exception`` and
    re-raised.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            # Narrowed from a bare ``except:`` so KeyboardInterrupt and
            # SystemExit propagate untouched.
            annotate_exception("when loading "+name)
            raise
900
def test_old_name():
    # type: () -> None
    """
    Load and run the cylinder model through its old sasview name,
    sas.models.CylinderModel.

    Does nothing when old-style plugin support is disabled or when the
    ``sas`` package cannot be imported.
    """
    if SUPPORT_OLD_STYLE_PLUGINS:
        try:
            # Only exercise the old name when sasview itself is importable.
            import sas  # pylint: disable=unused-import
        except ImportError:
            pass
        else:
            load_standard_models()
            from sas.models.CylinderModel import CylinderModel
            CylinderModel().evalDistribution([0.1, 0.1])
916
def magnetic_demo():
    """
    Plot the 2-D scattering of the standard sphere model with sld_M0 = 8.

    The model is evaluated with ``calculate_Iq`` on a 500x500 grid with
    qx and qy in [-0.35, 0.35].  The demo shows ``log(I + 0.001)``; the
    small offset keeps the logarithm finite where I is zero.  This is an
    interactive demo, not a test.
    """
    # pyplot replaces the discouraged pylab interface.
    import matplotlib.pyplot as plt

    model = _make_standard_model('sphere')()
    model.setParam('sld_M0', 8)
    q = np.linspace(-0.35, 0.35, 500)
    qx, qy = np.meshgrid(q, q)
    result = model.calculate_Iq(qx.flatten(), qy.flatten())
    result = result.reshape(qx.shape)

    plt.imshow(np.log(result + 0.001))
    plt.show()
929
if __name__ == "__main__":
    # Smoke run: print the cylinder model evaluated at [0.1, 0.1].
    # Other checks that can be enabled by hand:
    #   magnetic_demo(), test_product(), test_structure_factor(),
    #   print("rpa:", test_rpa()), test_empty_distribution()
    cylinder_value = test_cylinder()
    print("cylinder(0.1,0.1)=%g" % cylinder_value)
Note: See TracBrowser for help on using the repository browser.