source: sasmodels/sasmodels/sasview_model.py @ 9150036

core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 9150036 was 9150036, checked in by Paul Kienzle <pkienzle@…>, 6 months ago

Merge branch 'master' into beta_approx

  • Property mode set to 100644
File size: 35.7 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17from os.path import basename, splitext, abspath, getmtime
18try:
19    import _thread as thread
20except ImportError:
21    import thread
22
23import numpy as np  # type: ignore
24
25from . import core
26from . import custom
27from . import kernelcl
28from . import product
29from . import generate
30from . import weights
31from . import modelinfo
32from .details import make_kernel_args, dispersion_mesh
33from .kernelcl import reset_environment
34
35# pylint: disable=unused-import
36try:
37    from typing import (Dict, Mapping, Any, Sequence, Tuple, NamedTuple,
38                        List, Optional, Union, Callable)
39    from .modelinfo import ModelInfo, Parameter
40    from .kernel import KernelModel
41    MultiplicityInfoType = NamedTuple(
42        'MultiplicityInfo',
43        [("number", int), ("control", str), ("choices", List[str]),
44         ("x_axis_label", str)])
45    SasviewModelType = Callable[[int], "SasviewModel"]
46except ImportError:
47    pass
48# pylint: enable=unused-import
49
logger = logging.getLogger(__name__)

#: Lock that serializes kernel evaluations (taken by calculate_Iq and
#: calc_composition_models).
calculation_lock = thread.allocate_lock()

#: When True, plugins written against the old sasview model names and
#: parameters remain available.
SUPPORT_OLD_STYLE_PLUGINS = True

# TODO: separate x_axis_label from multiplicity info
#: Multiplicity description: number of variants, name of the control
#: parameter, the variant choices and the x-axis label for profile plots.
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo', "number control choices x_axis_label")

#: Registry of defined models (standard and custom), keyed by model name.
MODELS = {}  # type: Dict[str, SasviewModelType]
# TODO: remove unused MODEL_BY_PATH cache once sasview no longer references it
#: custom model {path: model} mapping so we can check timestamps
MODEL_BY_PATH = {}  # type: Dict[str, SasviewModelType]
#: Module last loaded for each custom model path; load_custom_model compares
#: against it to tell whether the model changed since it was last loaded.
_CACHED_MODULE = {}  # type: Dict[str, "module"]
72
def reset_environment():
    # type: () -> None
    """
    Drop compiled kernels so that the GUI can switch compute devices.

    The kernelcl environment is reset and the cached kernel of every model
    class registered in MODELS is cleared.  Kernels used by open fit pages
    are discarded as well; they are rebuilt the next time they are needed.
    """
    kernelcl.reset_environment()
    for model_class in MODELS.values():
        model_class._model = None
84
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Look up a model class by name.

    Names ending in ".py" are loaded with load_custom_model; any other name
    must already be registered in MODELS.

    :raises ValueError: if *modelname* is not a .py path or a known model
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
99
100
101# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Register every model named by core.list_models() and return all models.

    A model that fails to build is skipped and its traceback is logged at
    error level.  When SUPPORT_OLD_STYLE_PLUGINS is set, the models are
    also published under their old sasview names.

    The returned list holds every entry of MODELS, so custom models that
    were loaded earlier are included too.
    """
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
        except Exception:
            logger.error(traceback.format_exc())
            continue
        MODELS[model_name] = model_class

    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()

    return list(MODELS.values())
119
120
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load a custom model given the model path.

    The kernel module is loaded through custom.load_custom_kernel_module.
    Old-style plugins provide a ``Model`` class directly; any other module
    is converted with make_model_from_info.

    The model is registered in MODELS under its name.  If a model with the
    same name but a different source file is already registered, the model
    is renamed to its id, or to ``<id>_user`` if that name is taken as well.
    An existing MODELS entry is only replaced when the module has changed
    since the last call for *path*.

    :param path: path to the custom model file
    :return: the model class registered in MODELS for this model
    """
    #logger.info("Loading model %s", path)

    # Load the kernel module.  This may already be cached by the loader, so
    # only requires checking the timestamps of the dependents.
    kernel_module = custom.load_custom_kernel_module(path)

    # Check if the module has changed since we last looked.
    reloaded = kernel_module != _CACHED_MODULE.get(path, None)
    _CACHED_MODULE[path] = kernel_module

    # Turn the module into a model.  We need to do this even if the
    # model has already been loaded so that we can determine the model
    # name and retrieve it from the MODELS cache.
    model = getattr(kernel_module, 'Model', None)
    if model is not None:
        # Old style models do not set the name in the class attributes, so
        # set it here; this name will be overridden when the object is created
        # with an instance variable that has the same value.
        if model.name == "":
            model.name = splitext(basename(path))[0]
        if not hasattr(model, 'filename'):
            # Report the source file rather than the compiled one.  Only the
            # extension is rewritten; str.replace('.pyc', '.py') would also
            # alter a ".pyc" appearing elsewhere in the path.
            root, ext = splitext(abspath(kernel_module.__file__))
            model.filename = root + ('.py' if ext == '.pyc' else ext)
        if not hasattr(model, 'id'):
            model.id = splitext(basename(model.filename))[0]
    else:
        model_info = modelinfo.make_model_info(kernel_module)
        model = make_model_from_info(model_info)

    # If a model name already exists and we are loading a different model,
    # use the model file name as the model name.
    if model.name in MODELS and model.filename != MODELS[model.name].filename:
        previous_name = model.name
        model.name = model.id

        # If the new model name is still in the model list (for instance,
        # if we put a cylinder.py in our plug-in directory), then append
        # an identifier.
        if model.name in MODELS and model.filename != MODELS[model.name].filename:
            model.name = model.id + '_user'
        logger.info("Model %s already exists: using %s [%s]",
                    previous_name, model.name, model.filename)

    # Only update the model if the module has changed
    if reloaded or model.name not in MODELS:
        MODELS[model.name] = model

    return MODELS[model.name]
173
174
def make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Build a SasviewModel subclass that wraps *model_info*.

    The class is named after model_info.name, carries the attributes from
    _generate_model_attributes, and records the source file of the model
    in its *filename* attribute.
    """
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs = _generate_model_attributes(model_info)
    class_attrs.update(__init__=__init__, filename=model_info.filename)
    return type(model_info.name, (SasviewModel,), class_attrs)
187
188
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build the sasview model class for the model called *name*.

    *name* is resolved by generate.load_kernel_module (a standard model
    name, or a path to a custom model).

    Returns a class that can be used directly as a sasview model.
    """
    return make_model_from_info(
        modelinfo.make_model_info(generate.load_kernel_module(name)))
201
202
def _register_old_models():
    # type: () -> None
    """
    Place the new models into sasview under the old names.

    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
    is available to the plugin modules.  Each model in the 3.1.2 conversion
    table is then exposed as a pseudo-module ``sas.models.<OldName>`` with an
    attribute ``<OldName>`` that refers to the new model class.

    Models that are not registered (for example because they failed to load
    in load_standard_models) are skipped with a warning.  Previously the
    resulting ValueError aborted the whole registration.
    """
    import sys
    import sas   # needed in order to set sas.models
    import sas.sascalc.fit
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit
    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE

    for new_name, conversion in CONVERSION_TABLE.get((3, 1, 2), {}).items():
        # CoreShellEllipsoidModel => core_shell_ellipsoid:1
        new_name = new_name.split(':')[0]
        old_name = conversion[0] if len(conversion) < 3 else conversion[2]
        try:
            new_model = find_model(new_name)
        except ValueError:
            logger.warning("Old model name %s not registered: model %s is"
                           " unavailable", old_name, new_name)
            continue
        ConstructedModule = type(old_name, (), {old_name: new_model})
        old_path = 'sas.models.' + old_name
        # Publish the pseudo-module as an attribute of sas.models and in
        # sys.modules so that both attribute access and import statements
        # resolve it.  The attribute must use the bare old name; a dotted
        # attribute name is only reachable through getattr().
        setattr(sas.models, old_name, ConstructedModule)
        sys.modules[old_path] = ConstructedModule
228
229
def MultiplicationModel(form_factor, structure_factor):
    # type: ("SasviewModel", "SasviewModel") -> "SasviewModel"
    """
    Build a product model instance from a form factor and a structure factor.

    The product model takes the multiplicity of *form_factor*.
    """
    product_info = product.make_product_info(
        form_factor._model_info, structure_factor._model_info)
    product_class = make_model_from_info(product_info)
    return product_class(form_factor.multiplicity)
239
240
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Build the class attribute dictionary for a SasView model wrapper.

    The attributes carry everything needed to query the model details
    (name, category, parameter groupings, multiplicity) without creating an
    instance.  They are meant to be treated as immutable.
    """
    # TODO: allow model to override axis labels input/output name/unit
    kernel_pars = model_info.parameters.kernel_parameters
    if model_info.profile is not None:
        profile_label = model_info.profile_axes[0]
    else:
        profile_label = ""

    # Multiplicity: the first kernel parameter named by model_info.control
    # selects the variant and is never fitted.
    non_fittable = []  # type: List[str]
    control_par = next(
        (par for par in kernel_pars if par.name == model_info.control), None)
    if control_par is None:
        variants = MultiplicityInfo(0, "", [], profile_label)
    else:
        non_fittable.append(control_par.name)
        if control_par.choices:
            n_variants = len(control_par.choices)
        else:
            n_variants = int(control_par.limits[1])
        variants = MultiplicityInfo(
            n_variants, control_par.name, control_par.choices, profile_label)

    # Only the first parameter with choices is offered as a drop-down list;
    # when it has length > 1 its numbered elements are marked non-fittable.
    choice_par = next((par for par in kernel_pars if par.choices), None)
    if choice_par is None:
        fun_list = []
    else:
        fun_list = choice_par.choices
        if choice_par.length > 1:
            non_fittable.extend(choice_par.id + str(k)
                                for k in range(1, choice_par.length + 1))

    # Group the 2D user parameters.  Magnetic parameters are listed among
    # the orientation parameters as well as on their own.
    orientation_params = []  # type: List[str]
    magnetic_params = []  # type: List[str]
    fixed = []  # type: List[str]
    for par in model_info.parameters.user_parameters({}, is2d=True):
        if par.type == 'orientation':
            orientation_params.extend((par.name, par.name + ".width"))
        elif par.type == 'magnetic':
            orientation_params.append(par.name)
            magnetic_params.append(par.name)
        else:
            continue
        fixed.append(par.name + ".width")

    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.effective_radius_type is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }
309
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete model classes are derived from this base by
    make_model_from_info, which fills in the class attributes below from
    _generate_model_attributes.  Instances hold the current parameter
    values (*params*), dispersion settings (*dispersion*) and parameter
    units and limits (*details*).
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    #: compiled kernel; _calculate_Iq stores it on the class so that all
    #: instances share it, and reset_environment clears it
    _model = None       # type: KernelModel
    #: model definition the class was built from
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters (each followed by its ".width"
    #: entry) and of the magnetic parameters, in the order they appear
    orientation_params = None # type: Sequence[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: Sequence[str]
    #: ".width" names of the orientation and magnetic parameters;
    #: is_fittable() reports True for exactly the names in this list
    fixed = None              # type: Sequence[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity (passed to the kernel as *cutoff*)
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]
371
372    def __init__(self, multiplicity=None):
373        # type: (Optional[int]) -> None
374
375        # TODO: _persistency_dict to persistency_dict throughout sasview
376        # TODO: refactor multiplicity to encompass variants
377        # TODO: dispersion should be a class
378        # TODO: refactor multiplicity info
379        # TODO: separate profile view from multiplicity
380        # The button label, x and y axis labels and scale need to be under
381        # the control of the model, not the fit page.  Maximum flexibility,
382        # the fit page would supply the canvas and the profile could plot
383        # how it wants, but this assumes matplotlib.  Next level is that
384        # we provide some sort of data description including title, labels
385        # and lines to plot.
386
387        # Get the list of hidden parameters given the multiplicity
388        # Don't include multiplicity in the list of parameters
389        self.multiplicity = multiplicity
390        if multiplicity is not None:
391            hidden = self._model_info.get_hidden_parameters(multiplicity)
392            hidden |= set([self.multiplicity_info.control])
393        else:
394            hidden = set()
395        if self._model_info.structure_factor:
396            hidden.add('scale')
397            hidden.add('background')
398
399        # Update the parameter lists to exclude any hidden parameters
400        self.magnetic_params = tuple(pname for pname in self.magnetic_params
401                                     if pname not in hidden)
402        self.orientation_params = tuple(pname for pname in self.orientation_params
403                                        if pname not in hidden)
404
405        self._persistency_dict = {}
406        self.params = collections.OrderedDict()
407        self.dispersion = collections.OrderedDict()
408        self.details = {}
409        for p in self._model_info.parameters.user_parameters({}, is2d=True):
410            if p.name in hidden:
411                continue
412            self.params[p.name] = p.default
413            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
414            if p.polydisperse:
415                self.details[p.id+".width"] = [
416                    "", 0.0, 1.0 if p.relative_pd else np.inf
417                ]
418                self.dispersion[p.name] = {
419                    'width': 0,
420                    'npts': 35,
421                    'nsigmas': 3,
422                    'type': 'gaussian',
423                }
424
425    def __get_state__(self):
426        # type: () -> Dict[str, Any]
427        state = self.__dict__.copy()
428        state.pop('_model')
429        # May need to reload model info on set state since it has pointers
430        # to python implementations of Iq, etc.
431        #state.pop('_model_info')
432        return state
433
434    def __set_state__(self, state):
435        # type: (Dict[str, Any]) -> None
436        self.__dict__ = state
437        self._model = None
438
439    def __str__(self):
440        # type: () -> str
441        """
442        :return: string representation
443        """
444        return self.name
445
446    def is_fittable(self, par_name):
447        # type: (str) -> bool
448        """
449        Check if a given parameter is fittable or not
450
451        :param par_name: the parameter name to check
452        """
453        return par_name in self.fixed
454        #For the future
455        #return self.params[str(par_name)].is_fittable()
456
457
458    def getProfile(self):
459        # type: () -> (np.ndarray, np.ndarray)
460        """
461        Get SLD profile
462
463        : return: (z, beta) where z is a list of depth of the transition points
464                beta is a list of the corresponding SLD values
465        """
466        args = {} # type: Dict[str, Any]
467        for p in self._model_info.parameters.kernel_parameters:
468            if p.id == self.multiplicity_info.control:
469                value = float(self.multiplicity)
470            elif p.length == 1:
471                value = self.params.get(p.id, np.NaN)
472            else:
473                value = np.array([self.params.get(p.id+str(k), np.NaN)
474                                  for k in range(1, p.length+1)])
475            args[p.id] = value
476
477        x, y = self._model_info.profile(**args)
478        return x, 1e-6*y
479
480    def setParam(self, name, value):
481        # type: (str, float) -> None
482        """
483        Set the value of a model parameter
484
485        :param name: name of the parameter
486        :param value: value of the parameter
487
488        """
489        # Look for dispersion parameters
490        toks = name.split('.')
491        if len(toks) == 2:
492            for item in self.dispersion.keys():
493                if item == toks[0]:
494                    for par in self.dispersion[item]:
495                        if par == toks[1]:
496                            self.dispersion[item][par] = value
497                            return
498        else:
499            # Look for standard parameter
500            for item in self.params.keys():
501                if item == name:
502                    self.params[item] = value
503                    return
504
505        raise ValueError("Model does not contain parameter %s" % name)
506
507    def getParam(self, name):
508        # type: (str) -> float
509        """
510        Set the value of a model parameter
511
512        :param name: name of the parameter
513
514        """
515        # Look for dispersion parameters
516        toks = name.split('.')
517        if len(toks) == 2:
518            for item in self.dispersion.keys():
519                if item == toks[0]:
520                    for par in self.dispersion[item]:
521                        if par == toks[1]:
522                            return self.dispersion[item][par]
523        else:
524            # Look for standard parameter
525            for item in self.params.keys():
526                if item == name:
527                    return self.params[item]
528
529        raise ValueError("Model does not contain parameter %s" % name)
530
531    def getParamList(self):
532        # type: () -> Sequence[str]
533        """
534        Return a list of all available parameters for the model
535        """
536        param_list = list(self.params.keys())
537        # WARNING: Extending the list with the dispersion parameters
538        param_list.extend(self.getDispParamList())
539        return param_list
540
541    def getDispParamList(self):
542        # type: () -> Sequence[str]
543        """
544        Return a list of polydispersity parameters for the model
545        """
546        # TODO: fix test so that parameter order doesn't matter
547        ret = ['%s.%s' % (p_name, ext)
548               for p_name in self.dispersion.keys()
549               for ext in ('npts', 'nsigmas', 'width')]
550        #print(ret)
551        return ret
552
553    def clone(self):
554        # type: () -> "SasviewModel"
555        """ Return a identical copy of self """
556        return deepcopy(self)
557
558    def run(self, x=0.0):
559        # type: (Union[float, (float, float), List[float]]) -> float
560        """
561        Evaluate the model
562
563        :param x: input q, or [q,phi]
564
565        :return: scattering function P(q)
566
567        **DEPRECATED**: use calculate_Iq instead
568        """
569        if isinstance(x, (list, tuple)):
570            # pylint: disable=unpacking-non-sequence
571            q, phi = x
572            result, _ = self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])
573            return result[0]
574        else:
575            result, _ = self.calculate_Iq([x])
576            return result[0]
577
578
579    def runXY(self, x=0.0):
580        # type: (Union[float, (float, float), List[float]]) -> float
581        """
582        Evaluate the model in cartesian coordinates
583
584        :param x: input q, or [qx, qy]
585
586        :return: scattering function P(q)
587
588        **DEPRECATED**: use calculate_Iq instead
589        """
590        if isinstance(x, (list, tuple)):
591            result, _ = self.calculate_Iq([x[0]], [x[1]])
592            return result[0]
593        else:
594            result, _ = self.calculate_Iq([x])
595            return result[0]
596
597    def evalDistribution(self, qdist):
598        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
599        r"""
600        Evaluate a distribution of q-values.
601
602        :param qdist: array of q or a list of arrays [qx,qy]
603
604        * For 1D, a numpy array is expected as input
605
606        ::
607
608            evalDistribution(q)
609
610          where *q* is a numpy array.
611
612        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input
613
614        ::
615
616              qx = [ qx[0], qx[1], qx[2], ....]
617              qy = [ qy[0], qy[1], qy[2], ....]
618
619        If the model is 1D only, then
620
621        .. math::
622
623            q = \sqrt{q_x^2+q_y^2}
624
625        """
626        if isinstance(qdist, (list, tuple)):
627            # Check whether we have a list of ndarrays [qx,qy]
628            qx, qy = qdist
629            result, _ = self.calculate_Iq(qx, qy)
630            return result
631
632        elif isinstance(qdist, np.ndarray):
633            # We have a simple 1D distribution of q-values
634            result, _ = self.calculate_Iq(qdist)
635            return result
636
637        else:
638            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
639                            % type(qdist))
640
641    def calc_composition_models(self, qx):
642        """
643        returns parts of the composition model or None if not a composition
644        model.
645        """
646        # TODO: have calculate_Iq return the intermediates.
647        #
648        # The current interface causes calculate_Iq() to be called twice,
649        # once to get the combined result and again to get the intermediate
650        # results.  This is necessary for now.
651        # Long term, the solution is to change the interface to calculate_Iq
652        # so that it returns a results object containing all the bits:
653        #     the A, B, C, ... of the composition model (and any subcomponents?)
654        #     the P and S of the product model
655        #     the combined model before resolution smearing,
656        #     the sasmodel before sesans conversion,
657        #     the oriented 2D model used to fit oriented usans data,
658        #     the final I(q),
659        #     ...
660        #
661        # Have the model calculator add all of these blindly to the data
662        # tree, and update the graphs which contain them.  The fitter
663        # needs to be updated to use the I(q) value only, ignoring the rest.
664        #
665        # The simple fix of returning the existing intermediate results
666        # will not work for a couple of reasons: (1) another thread may
667        # sneak in to compute its own results before calc_composition_models
668        # is called, and (2) calculate_Iq is currently called three times:
669        # once with q, once with q values before qmin and once with q values
670        # after q max.  Both of these should be addressed before
671        # replacing this code.
672        composition = self._model_info.composition
673        if composition and composition[0] == 'product': # only P*S for now
674            with calculation_lock:
675                _, lazy_results = self._calculate_Iq(qx)
676                # for compatibility with sasview 4.x
677                results = lazy_results()
678                pq_data = results.get("P(Q)")
679                sq_data = results.get("S(Q)")
680                return pq_data, sq_data
681        else:
682            return None
683
684    def calculate_Iq(self,
685                     qx,     # type: Sequence[float]
686                     qy=None # type: Optional[Sequence[float]]
687                     ):
688        # type: (...) -> Tuple[np.ndarray, Callable[[], collections.OrderedDict[str, np.ndarray]]]
689        """
690        Calculate Iq for one set of q with the current parameters.
691
692        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.
693
694        This should NOT be used for fitting since it copies the *q* vectors
695        to the card for each evaluation.
696
697        The returned tuple contains the scattering intensity followed by a
698        callable which returns a dictionary of intermediate data from
699        ProductKernel.
700        """
701        ## uncomment the following when trying to debug the uncoordinated calls
702        ## to calculate_Iq
703        #if calculation_lock.locked():
704        #    logger.info("calculation waiting for another thread to complete")
705        #    logger.info("\n".join(traceback.format_stack()))
706
707        with calculation_lock:
708            return self._calculate_Iq(qx, qy)
709
710    def _calculate_Iq(self, qx, qy=None):
711        if self._model is None:
712            # Only need one copy of the compiled kernel regardless of how many
713            # times it is used, so store it in the class.  Also, to reset the
714            # compute engine, need to clear out all existing compiled kernels,
715            # which is much easier to do if we store them in the class.
716            self.__class__._model = core.build_model(self._model_info)
717        if qy is not None:
718            q_vectors = [np.asarray(qx), np.asarray(qy)]
719        else:
720            q_vectors = [np.asarray(qx)]
721        calculator = self._model.make_kernel(q_vectors)
722        parameters = self._model_info.parameters
723        pairs = [self._get_weights(p) for p in parameters.call_parameters]
724        #weights.plot_weights(self._model_info, pairs)
725        call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
726        #call_details.show()
727        #print("================ parameters ==================")
728        #for p, v in zip(parameters.call_parameters, pairs): print(p.name, v[0])
729        #for k, p in enumerate(self._model_info.parameters.call_parameters):
730        #    print(k, p.name, *pairs[k])
731        #print("params", self.params)
732        #print("values", values)
733        #print("is_mag", is_magnetic)
734        result = calculator(call_details, values, cutoff=self.cutoff,
735                            magnetic=is_magnetic)
736        lazy_results = getattr(calculator, 'results',
737                               lambda: collections.OrderedDict())
738        #print("result", result)
739
740        calculator.release()
741        #self._model.release()
742
743        return result, lazy_results
744
745
746    def calculate_ER(self, mode=1):
747        # type: () -> float
748        """
749        Calculate the effective radius for P(q)*S(q)
750
751        *mode* is the R_eff type, which defaults to 1 to match the ER
752        calculation for sasview models from version 3.x.
753
754        :return: the value of the effective radius
755        """
756        # ER and VR are only needed for old multiplication models, based on
757        # sas.sascalc.fit.MultiplicationModel.  Fail for now.  If we want to
758        # continue supporting them then add some test cases so that the code
759        # is exercised.  We can access ER/VR using the kernel Fq function by
760        # extending _calculate_Iq so that it calls:
761        #    if er_mode > 0:
762        #        res = calculator.Fq(call_details, values, cutoff=self.cutoff,
763        #                            magnetic=False, effective_radius_type=mode)
764        #        R_eff, form_shell_ratio = res[2], res[4]
765        #        return R_eff, form_shell_ratio
766        # Then use the following in calculate_ER:
767        #    ER, VR = self._calculate_Iq(q=[0.1], er_mode=mode)
768        #    return ER
769        # Similarly, for calculate_VR:
770        #    ER, VR = self._calculate_Iq(q=[0.1], er_mode=1)
771        #    return VR
772        # Obviously a combined calculate_ER_VR method would be better, but
773        # we only need them to support very old models, so ignore the 2x
774        # performance hit.
775        raise NotImplementedError("ER function is no longer available.")
776
777    def calculate_VR(self):
778        # type: () -> float
779        """
780        Calculate the volf ratio for P(q)*S(q)
781
782        :return: the value of the form:shell volume ratio
783        """
784        # See comments in calculate_ER.
785        raise NotImplementedError("VR function is no longer available.")
786
787    def set_dispersion(self, parameter, dispersion):
788        # type: (str, weights.Dispersion) -> None
789        """
790        Set the dispersion object for a model parameter
791
792        :param parameter: name of the parameter [string]
793        :param dispersion: dispersion object of type Dispersion
794        """
795        if parameter in self.params:
796            # TODO: Store the disperser object directly in the model.
797            # The current method of relying on the sasview GUI to
798            # remember them is kind of funky.
799            # Note: can't seem to get disperser parameters from sasview
800            # (1) Could create a sasview model that has not yet been
801            # converted, assign the disperser to one of its polydisperse
802            # parameters, then retrieve the disperser parameters from the
803            # sasview model.
804            # (2) Could write a disperser parameter retriever in sasview.
805            # (3) Could modify sasview to use sasmodels.weights dispersers.
806            # For now, rely on the fact that the sasview only ever uses
807            # new dispersers in the set_dispersion call and create a new
808            # one instead of trying to assign parameters.
809            self.dispersion[parameter] = dispersion.get_pars()
810        else:
811            raise ValueError("%r is not a dispersity or orientation parameter"
812                             % parameter)
813
814    def _dispersion_mesh(self):
815        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
816        """
817        Create a mesh grid of dispersion parameters and weights.
818
819        Returns [p1,p2,...],w where pj is a vector of values for parameter j
820        and w is a vector containing the products for weights for each
821        parameter set in the vector.
822        """
823        pars = [self._get_weights(p)
824                for p in self._model_info.parameters.call_parameters
825                if p.type == 'volume']
826        return dispersion_mesh(self._model_info, pars)
827
828    def _get_weights(self, par):
829        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
830        """
831        Return dispersion weights for parameter
832        """
833        if par.name not in self.params:
834            if par.name == self.multiplicity_info.control:
835                return self.multiplicity, [self.multiplicity], [1.0]
836            else:
837                # For hidden parameters use default values.  This sets
838                # scale=1 and background=0 for structure factors
839                default = self._model_info.parameters.defaults.get(par.name, np.NaN)
840                return default, [default], [1.0]
841        elif par.polydisperse:
842            value = self.params[par.name]
843            dis = self.dispersion[par.name]
844            if dis['type'] == 'array':
845                dispersity, weight = dis['values'], dis['weights']
846            else:
847                dispersity, weight = weights.get_weights(
848                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
849                    value, par.limits, par.relative_pd)
850            return value, dispersity, weight
851        else:
852            value = self.params[par.name]
853            return value, [value], [1.0]
854
855    @classmethod
856    def runTests(cls):
857        """
858        Run any tests built into the model and captures the test output.
859
860        Returns success flag and output
861        """
862        from .model_test import check_model
863        return check_model(cls._model_info)
864
def test_cylinder():
    # type: () -> float
    """
    Smoke test: build the cylinder model and evaluate it at qx=qy=0.1.

    The computed intensity is returned so the caller can display it.
    """
    model = _make_standard_model('cylinder')()
    q_point = [0.1, 0.1]
    return model.evalDistribution(q_point)
873
def test_structure_factor():
    # type: () -> None
    """
    Test that the hardsphere model runs in both 2-D and 1-D without NaN.

    The 2-D point (0.1, 0.1) and the 1-D point 0.1*sqrt(2) have the same
    |q|, so both evaluations probe the same scattering vector magnitude.

    :raise ValueError: if either evaluation contains NaN.
    """
    Model = _make_standard_model('hardsphere')
    model = Model()
    value2d = model.evalDistribution([0.1, 0.1])
    value1d = model.evalDistribution(np.array([0.1*np.sqrt(2)]))
    #print("hardsphere", value1d, value2d)
    # .any() keeps the check well defined whatever the result length.
    if np.isnan(value1d).any() or np.isnan(value2d).any():
        raise ValueError("hardsphere returns nan")
886
def test_product():
    # type: () -> None
    """
    Test that the 2-D cylinder*hayter_msa product model runs and doesn't
    produce NaN.

    :raise ValueError: if the product model returns NaN.
    """
    S = _make_standard_model('hayter_msa')()
    P = _make_standard_model('cylinder')()
    model = MultiplicationModel(P, S)
    # Set the product's radius-mode parameter (product.RADIUS_MODE_ID) to 1.
    model.setParam(product.RADIUS_MODE_ID, 1.0)
    value = model.evalDistribution([0.1, 0.1])
    if np.isnan(value).any():
        raise ValueError("cylinder*hayter_msa returns nan")
899
def test_rpa():
    # type: () -> float
    """
    Smoke test for the RPA model: build it with multiplicity 3 and return
    its value at the 2-D point qx=qy=0.1.
    """
    rpa_model = _make_standard_model('rpa')(3)
    return rpa_model.evalDistribution([0.1, 0.1])
908
def test_empty_distribution():
    # type: () -> None
    """
    Make sure that sasmodels returns zero when there are no polydispersity
    points.

    A radius of -1 is used to produce the empty distribution; with the
    background set to 0 the computed intensity must be exactly 0.
    """
    Cylinder = _make_standard_model('cylinder')
    cylinder = Cylinder()
    cylinder.setParam('radius', -1.0)
    cylinder.setParam('background', 0.)
    Iq = cylinder.evalDistribution(np.asarray([0.1]))
    assert Iq[0] == 0., "empty distribution fails"
920
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models.

    A failure while building any model is re-raised after being passed
    through annotate_exception with the name of the offending model.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            # Catch Exception rather than using a bare except, so that
            # KeyboardInterrupt/SystemExit are not annotated.  The
            # annotation presumably attaches the model name to the
            # in-flight error before it propagates.
            annotate_exception("when loading "+name)
            raise
933
def test_old_name():
    # type: () -> None
    """
    Load and run the cylinder model through its old name,
    sas.models.CylinderModel.

    Returns without doing anything when old-style plugin support is off or
    when the sas (sasview) package cannot be imported.
    """
    if not SUPPORT_OLD_STYLE_PLUGINS:
        return
    try:
        import sas  # pylint: disable=unused-import
    except ImportError:
        sasview_on_path = False
    else:
        sasview_on_path = True
    if sasview_on_path:
        load_standard_models()
        from sas.models.CylinderModel import CylinderModel
        CylinderModel().evalDistribution([0.1, 0.1])
949
def test_structure_factor_background():
    # type: () -> None
    """
    Check that the sasview wrapper and DirectModel agree for hardsphere at
    q=0, both with an explicit background=0 and with DirectModel's default
    background.
    """
    from .data import empty_data1D
    from .core import load_model_info, build_model
    from .direct_model import DirectModel

    name = "hardsphere"
    q_values = [0.0]

    # Value computed through the sasview model wrapper.
    wrapped = _make_standard_model(name)()
    expected = wrapped.evalDistribution(np.array(q_values))[0]

    # The same point computed through DirectModel.
    data = empty_data1D(q_values)
    kernel_model = build_model(load_model_info(name))
    direct = DirectModel(data, kernel_model)

    assert expected == direct(background=0.0)

    # DirectModel's background should also default to zero.
    assert expected == direct()
976
977
def magnetic_demo():
    """
    Plot the 2-D scattering of a sphere with sld_M0=8.

    The sphere model is evaluated on a 500x500 qx/qy grid spanning
    [-0.35, 0.35], and log(I + 0.001) is shown with pylab.
    """
    model = _make_standard_model('sphere')()
    model.setParam('sld_M0', 8)
    axis = np.linspace(-0.35, 0.35, 500)
    qx, qy = np.meshgrid(axis, axis)
    intensity, _ = model.calculate_Iq(qx.flatten(), qy.flatten())
    intensity = intensity.reshape(qx.shape)

    import pylab
    # The small offset keeps the logarithm finite where the intensity is 0.
    pylab.imshow(np.log(intensity + 0.001))
    pylab.show()
990
if __name__ == "__main__":
    # Quick smoke run; the other demos/checks below can be enabled by hand.
    cylinder_value = test_cylinder()
    print("cylinder(0.1,0.1)=%g" % cylinder_value)
    #magnetic_demo()
    #test_product()
    #test_structure_factor()
    #print("rpa:", test_rpa())
    #test_empty_distribution()
    #test_structure_factor_background()
Note: See TracBrowser for help on using the repository browser.