source: sasmodels/sasmodels/sasview_model.py @ 4e96703

Last change on this file since 4e96703 was 4e96703, checked in by Paul Kienzle <pkienzle@…>, 5 years ago

Merge branch 'beta_approx' into ticket-608-user-defined-weights

  • Property mode set to 100644
File size: 33.3 KB
Line 
1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
9"""
10from __future__ import print_function
11
12import math
13from copy import deepcopy
14import collections
15import traceback
16import logging
17from os.path import basename, splitext, abspath, getmtime
18try:
19    import _thread as thread
20except ImportError:
21    import thread
22
23import numpy as np  # type: ignore
24
25from . import core
26from . import custom
27from . import product
28from . import generate
29from . import weights
30from . import modelinfo
31from .details import make_kernel_args, dispersion_mesh
32
33# Hack: load in any custom distributions
34# Uses ~/.sasview/weights/*.py unless SASMODELS_WEIGHTS is set in the environ.
35# Override with weights.load_weights(pattern="<weights_path>/*.py")
36weights.load_weights()
37
38# pylint: disable=unused-import
39try:
40    from typing import (Dict, Mapping, Any, Sequence, Tuple, NamedTuple,
41                        List, Optional, Union, Callable)
42    from .modelinfo import ModelInfo, Parameter
43    from .kernel import KernelModel
44    MultiplicityInfoType = NamedTuple(
45        'MultiplicityInfo',
46        [("number", int), ("control", str), ("choices", List[str]),
47         ("x_axis_label", str)])
48    SasviewModelType = Callable[[int], "SasviewModel"]
49except ImportError:
50    pass
51# pylint: enable=unused-import
52
# Module-level logger used for load errors and model-name collisions.
logger = logging.getLogger(__name__)

# Lock held by SasviewModel.calculate_Iq and calc_composition_models while
# a kernel evaluation is in progress.
calculation_lock = thread.allocate_lock()

#: True if pre-existing plugins, with the old names and parameters, should
#: continue to be supported.  When set, load_standard_models() also calls
#: _register_old_models().
SUPPORT_OLD_STYLE_PLUGINS = True

# TODO: separate x_axis_label from multiplicity info
#: Record describing the multiplicity control of a model: the number of
#: variants, the name of the controlling parameter, its list of choices and
#: the x axis label used for the profile plot.
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo',
    ["number", "control", "choices", "x_axis_label"],
)

#: set of defined models (standard and custom), keyed by model name
MODELS = {}  # type: Dict[str, SasviewModelType]
# TODO: remove unused MODEL_BY_PATH cache once sasview no longer references it
#: custom model {path: model} mapping so we can check timestamps
MODEL_BY_PATH = {}  # type: Dict[str, SasviewModelType]
#: Track modules that we have loaded, keyed by the path given to
#: load_custom_model(), so we can determine whether the model has changed
#: since we last reloaded.
_CACHED_MODULE = {}  # type: Dict[str, "module"]
75
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Look up a model class by name.

    Names ending in ``.py`` are treated as paths and handed to
    :func:`load_custom_model`; any other name must already be registered
    in :data:`MODELS`.

    Raises *ValueError* if the name is neither a ``.py`` path nor a
    registered model.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
90
91
92# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Build every model listed by ``core.list_models()`` and register it in
    :data:`MODELS`, then return all registered models as a list.

    A model that fails to build is skipped; its traceback is written to
    the module logger at error level.  When
    :data:`SUPPORT_OLD_STYLE_PLUGINS` is set the old-style model names are
    registered as well.
    """
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
        except Exception:
            # Keep going: one broken model should not hide the others.
            logger.error(traceback.format_exc())
        else:
            MODELS[model_name] = model_class

    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()

    return list(MODELS.values())
110
111
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load the custom model stored at *path* and return its model class.

    The returned class is the one registered in :data:`MODELS`; a freshly
    built class replaces the registered one only when the kernel module
    differs from the one cached for *path* or the name is not yet
    registered.
    """
    #logger.info("Loading model %s", path)

    # The loader may hand back a cached module, so this mostly costs a
    # timestamp check on the dependents.
    kernel_module = custom.load_custom_kernel_module(path)

    # A module object different from the one we remembered means the
    # source was (re)loaded since the last call.
    previous_module = _CACHED_MODULE.get(path, None)
    reloaded = kernel_module != previous_module
    _CACHED_MODULE[path] = kernel_module

    # The model class must be produced even when it is already registered,
    # because its name is needed to look it up in the MODELS cache.
    model = getattr(kernel_module, 'Model', None)
    if model is None:
        model = make_model_from_info(modelinfo.make_model_info(kernel_module))
    else:
        # Old style models do not set the name in the class attributes, so
        # set it here; this name will be overridden when the object is
        # created with an instance variable that has the same value.
        if model.name == "":
            model.name = splitext(basename(path))[0]
        if not hasattr(model, 'filename'):
            model.filename = abspath(kernel_module.__file__).replace('.pyc', '.py')
        if not hasattr(model, 'id'):
            model.id = splitext(basename(model.filename))[0]

    def _taken_by_other_file(name):
        # True when *name* is registered for a model from a different file.
        return name in MODELS and MODELS[name].filename != model.filename

    # On a name collision with a different model fall back to the file
    # based id as the model name.
    if _taken_by_other_file(model.name):
        _previous_name = model.name
        model.name = model.id

        # The id may collide as well (for instance a cylinder.py placed in
        # the plug-in directory); disambiguate with a suffix.
        if _taken_by_other_file(model.name):
            model.name = model.id + '_user'
        logger.info("Model %s already exists: using %s [%s]",
                    _previous_name, model.name, model.filename)

    # Only update the model if the module has changed
    if reloaded or model.name not in MODELS:
        MODELS[model.name] = model

    return MODELS[model.name]
164
165
def make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Create a :class:`SasviewModel` subclass for *model_info*.

    The new class is named after ``model_info.name`` and carries the
    attributes produced by :func:`_generate_model_attributes`, plus
    ``filename`` and an ``__init__`` that forwards *multiplicity* to
    ``SasviewModel.__init__``.
    """
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_dict = dict(
        _generate_model_attributes(model_info),
        __init__=__init__,
        filename=model_info.filename,
    )
    ConstructedModel = type(model_info.name, (SasviewModel,), class_dict) # type: SasviewModelType
    return ConstructedModel
178
179
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build the sasview model class for *name*.

    *name* is passed to ``generate.load_kernel_module``; per the original
    documentation it can be a standard model name or a path to a custom
    model.

    Returns a class that can be used directly as a sasview model.
    """
    return make_model_from_info(
        modelinfo.make_model_info(generate.load_kernel_module(name)))
192
193
def _register_old_models():
    # type: () -> None
    """
    Make the new models importable under their old sasview names.

    ``sas.sascalc.fit`` is installed as ``sas.models`` (both as an attribute
    of ``sas`` and in ``sys.modules``) so that ``sas.models.pluginmodel``
    is available to plugin modules.  Then, for each entry of the (3, 1, 2)
    conversion table, a stand-in class holding the new model under its old
    name is registered in ``sys.modules`` as ``sas.models.<old_name>``.
    """
    # The order of these statements matters: the alias must be in place
    # before ``import sas.models`` is executed.
    import sys
    import sas   # needed in order to set sas.models
    import sas.sascalc.fit
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit
    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE

    table = CONVERSION_TABLE.get((3, 1, 2), {})
    for versioned_name, conversion in table.items():
        # CoreShellEllipsoidModel => core_shell_ellipsoid:1
        model_name = versioned_name.split(':')[0]
        if len(conversion) < 3:
            old_name = conversion[0]
        else:
            old_name = conversion[2]
        # A class with a single attribute plays the role of the old module.
        stand_in = type(old_name, (), {old_name: find_model(model_name)})
        old_path = 'sas.models.' + old_name
        # NOTE(review): the attribute is set under the full dotted path, not
        # under old_name, so it is not reachable with normal attribute
        # syntax -- confirm whether old_name was intended.
        setattr(sas.models, old_path, stand_in)
        sys.modules[old_path] = stand_in
219
220
def MultiplicationModel(form_factor, structure_factor):
    # type: ("SasviewModel", "SasviewModel") -> "SasviewModel"
    """
    Return an instance of the product model built from *form_factor* and
    *structure_factor*.

    The product model info is derived from the ``_model_info`` of the two
    arguments, and the instance is created with the multiplicity of
    *form_factor*.
    """
    p_info = form_factor._model_info
    s_info = structure_factor._model_info
    product_class = make_model_from_info(product.make_product_info(p_info, s_info))
    return product_class(form_factor.multiplicity)
230
231
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Build the dictionary of class attributes for a model.

    The attributes carry everything needed to query the model details, so
    a model does not have to be instantiated just to inspect it.

    Sequence-valued attributes are stored as tuples so that they cannot be
    modified by accident.
    """

    # TODO: allow model to override axis labels input/output name/unit

    kernel_pars = model_info.parameters.kernel_parameters
    xlabel = "" if model_info.profile is None else model_info.profile_axes[0]

    # Multiplicity: the first kernel parameter named by model_info.control
    # determines the number of variants and is itself not fittable.
    non_fittable = []  # type: List[str]
    variants = MultiplicityInfo(0, "", [], xlabel)
    control = next((p for p in kernel_pars if p.name == model_info.control),
                   None)
    if control is not None:
        non_fittable.append(control.name)
        if control.choices:
            count = len(control.choices)
        else:
            count = int(control.limits[1])
        variants = MultiplicityInfo(count, control.name, control.choices,
                                    xlabel)

    # Only a single drop-down list parameter available: use the first
    # parameter that defines choices.
    fun_list = []
    chooser = next((p for p in kernel_pars if p.choices), None)
    if chooser is not None:
        fun_list = chooser.choices
        if chooser.length > 1:
            non_fittable.extend(chooser.id+str(k)
                                for k in range(1, chooser.length+1))

    # Sort the user visible parameters into the sasview parameter sets.
    orientation_params = []
    magnetic_params = []
    fixed = []
    for p in model_info.parameters.user_parameters({}, is2d=True):
        width = p.name+".width"
        if p.type == 'orientation':
            orientation_params.extend((p.name, width))
            fixed.append(width)
        elif p.type == 'magnetic':
            orientation_params.append(p.name)
            magnetic_params.append(p.name)
            fixed.append(width)

    # Class dictionary
    attrs = {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.effective_radius_type is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }  # type: Dict[str, Any]

    return attrs
300
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete model classes are created by :func:`make_model_from_info`,
    which subclasses this class and fills in the class attributes listed
    below.  Instances hold the current parameter values and dispersion
    settings and evaluate the kernel on request.
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: List[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: List[str]
    #: names of the fittable parameters
    fixed = None              # type: List[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Initialise the parameter, dispersion and detail tables.

        :param multiplicity: multiplicity value for the model, or None if
            the model has no multiplicity
        """

        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            hidden = self._model_info.get_hidden_parameters(multiplicity)
            hidden |= set([self.multiplicity_info.control])
        else:
            hidden = set()
        if self._model_info.structure_factor:
            hidden.add('scale')
            hidden.add('background')
            # Note: this changes the defaults stored on the shared model
            # info, not just on this instance.
            self._model_info.parameters.defaults['background'] = 0.

        # Update the parameter lists to exclude any hidden parameters
        self.magnetic_params = tuple(pname for pname in self.magnetic_params
                                     if pname not in hidden)
        self.orientation_params = tuple(pname for pname in self.orientation_params
                                        if pname not in hidden)

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters({}, is2d=True):
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                # Width limit is 1.0 for relative dispersity, otherwise
                # unbounded.
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    # NOTE(review): the pickle/copy protocol looks for __getstate__ and
    # __setstate__; the two methods below use different names and so are
    # only run when called explicitly.  Renaming them would change how
    # clone() (deepcopy) and pickling behave, so they are left as they are.
    def __get_state__(self):
        # type: () -> Dict[str, Any]
        """
        Return a copy of the instance dictionary without the ``_model``
        entry.  Raises *KeyError* if ``_model`` has not been set on the
        instance.
        """
        state = self.__dict__.copy()
        state.pop('_model')
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __set_state__(self, state):
        # type: (Dict[str, Any]) -> None
        """
        Replace the instance dictionary with *state* and clear ``_model``
        so that it is rebuilt on the next calculation.
        """
        self.__dict__ = state
        self._model = None

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check
        """
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()


    def getProfile(self):
        # type: () -> Tuple[np.ndarray, np.ndarray]
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        args = {} # type: Dict[str, Any]
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == self.multiplicity_info.control:
                value = float(self.multiplicity)
            elif p.length == 1:
                value = self.params.get(p.id, np.nan)
            else:
                # Vector parameter: collect p1, p2, ... into an array, with
                # NaN for any element that is not in the parameter table.
                value = np.array([self.params.get(p.id+str(k), np.nan)
                                  for k in range(1, p.length+1)])
            args[p.id] = value

        x, y = self._model_info.profile(**args)
        return x, 1e-6*y

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        :param name: name of the parameter; use ``par.attr`` to set the
            dispersion attribute *attr* of parameter *par*
        :param value: value of the parameter

        Raises *ValueError* if the model does not have the parameter.
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Dispersion parameter, e.g. "radius.width"
            pars = self.dispersion.get(toks[0])
            if pars is not None and toks[1] in pars:
                pars[toks[1]] = value
                return
        elif name in self.params:
            # Standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        :param name: name of the parameter; use ``par.attr`` to get the
            dispersion attribute *attr* of parameter *par*

        Raises *ValueError* if the model does not have the parameter.
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Dispersion parameter, e.g. "radius.width"
            pars = self.dispersion.get(toks[0])
            if pars is not None and toks[1] in pars:
                return pars[toks[1]]
        elif name in self.params:
            # Standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model
        """
        # TODO: fix test so that parameter order doesn't matter
        ret = ['%s.%s' % (p_name, ext)
               for p_name in self.dispersion.keys()
               for ext in ('npts', 'nsigmas', 'width')]
        #print(ret)
        return ret

    def clone(self):
        # type: () -> "SasviewModel"
        """ Return a identical copy of self """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            result, _ = self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])
            return result[0]
        else:
            result, _ = self.calculate_Iq([x])
            return result[0]


    def runXY(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            result, _ = self.calculate_Iq([x[0]], [x[1]])
            return result[0]
        else:
            result, _ = self.calculate_Iq([x])
            return result[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        Raises *TypeError* if *qdist* is not a list, tuple or numpy array.
        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            result, _ = self.calculate_Iq(qx, qy)
            return result

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            result, _ = self.calculate_Iq(qdist)
            return result

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calc_composition_models(self, qx):
        """
        returns parts of the composition model or None if not a composition
        model.
        """
        # TODO: have calculate_Iq return the intermediates.
        #
        # The current interface causes calculate_Iq() to be called twice,
        # once to get the combined result and again to get the intermediate
        # results.  This is necessary for now.
        # Long term, the solution is to change the interface to calculate_Iq
        # so that it returns a results object containing all the bits:
        #     the A, B, C, ... of the composition model (and any subcomponents?)
        #     the P and S of the product model
        #     the combined model before resolution smearing,
        #     the sasmodel before sesans conversion,
        #     the oriented 2D model used to fit oriented usans data,
        #     the final I(q),
        #     ...
        #
        # Have the model calculator add all of these blindly to the data
        # tree, and update the graphs which contain them.  The fitter
        # needs to be updated to use the I(q) value only, ignoring the rest.
        #
        # The simple fix of returning the existing intermediate results
        # will not work for a couple of reasons: (1) another thread may
        # sneak in to compute its own results before calc_composition_models
        # is called, and (2) calculate_Iq is currently called three times:
        # once with q, once with q values before qmin and once with q values
        # after q max.  Both of these should be addressed before
        # replacing this code.
        composition = self._model_info.composition
        if composition and composition[0] == 'product': # only P*S for now
            with calculation_lock:
                _, lazy_results = self._calculate_Iq(qx)
                # for compatibility with sasview 4.x
                results = lazy_results()
                pq_data = results.get("P(Q)")
                sq_data = results.get("S(Q)")
                return pq_data, sq_data
        else:
            return None

    def calculate_Iq(self,
                     qx,     # type: Sequence[float]
                     qy=None # type: Optional[Sequence[float]]
                     ):
        # type: (...) -> Tuple[np.ndarray, Callable[[], collections.OrderedDict[str, np.ndarray]]]
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.

        The returned tuple contains the scattering intensity followed by a
        callable which returns a dictionary of intermediate data from
        ProductKernel.
        """
        ## uncomment the following when trying to debug the uncoordinated calls
        ## to calculate_Iq
        #if calculation_lock.locked():
        #    logger.info("calculation waiting for another thread to complete")
        #    logger.info("\n".join(traceback.format_stack()))

        with calculation_lock:
            return self._calculate_Iq(qx, qy)

    def _calculate_Iq(self, qx, qy=None, Fq=False, effective_radius_type=1):
        """
        Evaluate the kernel for the current parameters without locking.

        :param qx: q values, or the qx values if *qy* is given
        :param qy: qy values for a 2D evaluation, or None for 1D
        :param Fq: if True, return the result of the kernel's ``Fq`` method
            instead of the result of calling the kernel
        :param effective_radius_type: forwarded to ``Fq`` when *Fq* is True

        :return: ``(result, lazy_results)`` where *lazy_results* is the
            ``results`` attribute of the kernel if it has one, otherwise a
            callable returning an empty ordered dictionary
        """
        if self._model is None:
            # Build the kernel model on first use.
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        calculator = self._model.make_kernel(q_vectors)
        try:
            parameters = self._model_info.parameters
            pairs = [self._get_weights(p) for p in parameters.call_parameters]
            #weights.plot_weights(self._model_info, pairs)
            call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
            #call_details.show()
            #print("================ parameters ==================")
            #for p, v in zip(parameters.call_parameters, pairs): print(p.name, v[0])
            #for k, p in enumerate(self._model_info.parameters.call_parameters):
            #    print(k, p.name, *pairs[k])
            #print("params", self.params)
            #print("values", values)
            #print("is_mag", is_magnetic)
            if Fq:
                result = calculator.Fq(call_details, values, cutoff=self.cutoff,
                                       magnetic=is_magnetic,
                                       effective_radius_type=effective_radius_type)
            else:
                # Without this else branch the Fq result would be thrown
                # away and replaced by the plain kernel result.
                result = calculator(call_details, values, cutoff=self.cutoff,
                                    magnetic=is_magnetic)
            lazy_results = getattr(calculator, 'results',
                                   collections.OrderedDict)
            #print("result", result)
        finally:
            # Release the kernel even if the evaluation fails.
            calculator.release()
        #self._model.release()

        return result, lazy_results


    def calculate_ER(self, mode=1):
        # type: (int) -> float
        """
        Calculate the effective radius for P(q)*S(q)

        :param mode: effective radius type passed on to the kernel

        :return: the value of the effective radius
        """
        # The Fq result is indexed as in the original code: element 2 is
        # taken to be the effective radius -- TODO confirm against the
        # kernel's Fq return value.
        Fq, _ = self._calculate_Iq([0.1], Fq=True, effective_radius_type=mode)
        return Fq[2]

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the form:shell volume ratio
        """
        # The Fq result is indexed as in the original code: element 4 is
        # taken to be the volume ratio -- TODO confirm against the kernel's
        # Fq return value.  The default effective radius type is used.
        Fq, _ = self._calculate_Iq([0.1], Fq=True)
        return Fq[4]

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion

        Raises *ValueError* if *parameter* is not in the parameter table.
        """
        if parameter in self.params:
            # TODO: Store the disperser object directly in the model.
            # The current method of relying on the sasview GUI to
            # remember them is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.
            # (2) Could write a disperser parameter retriever in sasview.
            # (3) Could modify sasview to use sasmodels.weights dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

    def _dispersion_mesh(self):
        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
        """
        Return dispersion weights for parameter

        The result is a ``(value, dispersity, weight)`` triple.  Parameters
        that are not polydisperse, or that are not in the parameter table,
        get a single point with weight 1.
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return self.multiplicity, [self.multiplicity], [1.0]
            else:
                # For hidden parameters use default values.  This sets
                # scale=1 and background=0 for structure factors
                default = self._model_info.parameters.defaults.get(par.name, np.nan)
                return default, [default], [1.0]
        elif par.polydisperse:
            value = self.params[par.name]
            dis = self.dispersion[par.name]
            if dis['type'] == 'array':
                # User supplied distribution: use its points as given.
                dispersity, weight = dis['values'], dis['weights']
            else:
                dispersity, weight = weights.get_weights(
                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                    value, par.limits, par.relative_pd)
            return value, dispersity, weight
        else:
            value = self.params[par.name]
            return value, [value], [1.0]

    @classmethod
    def runTests(cls):
        """
        Run any tests built into the model and captures the test output.

        Returns success flag and output
        """
        from .model_test import check_model
        return check_model(cls._model_info)
835
def test_cylinder():
    # type: () -> float
    """
    Build the standard cylinder model with its default parameters and
    return the result of evaluating it at [0.1,0.1].
    """
    model_class = _make_standard_model('cylinder')
    return model_class().evalDistribution([0.1, 0.1])
844
def test_structure_factor():
    # type: () -> None
    """
    Check that the hardsphere model does not produce NaN.

    The model is evaluated once with the list [0.1, 0.1] and once with a
    one-element array holding 0.1*sqrt(2); *ValueError* is raised if
    either result is NaN.
    """
    hardsphere = _make_standard_model('hardsphere')()
    value2d = hardsphere.evalDistribution([0.1, 0.1])
    value1d = hardsphere.evalDistribution(np.array([0.1*np.sqrt(2)]))
    #print("hardsphere", value1d, value2d)
    # Checked in the same order as before: 1-D result first, then 2-D.
    if any(np.isnan(result) for result in (value1d, value2d)):
        raise ValueError("hardsphere returns nan")
857
def test_product():
    # type: () -> None
    """
    Test that the cylinder*hayter_msa product model runs and doesn't
    produce NaN.

    The product is built with *MultiplicationModel* from default instances
    of the two models, the parameter named by *product.RADIUS_MODE_ID* is
    set to 1.0, and the model is evaluated at [0.1, 0.1].

    Raises *ValueError* if the result is NaN.
    """
    S = _make_standard_model('hayter_msa')()
    P = _make_standard_model('cylinder')()
    model = MultiplicationModel(P, S)
    model.setParam(product.RADIUS_MODE_ID, 1.0)
    value = model.evalDistribution([0.1, 0.1])
    if np.isnan(value):
        # Message names the model actually tested (was misspelled
        # "hatyer_msa") and the condition actually checked (NaN, not null).
        raise ValueError("cylinder*hayter_msa returns NaN")
870
def test_rpa():
    # type: () -> float
    """
    Check that the RPA model runs.

    The standard rpa model class is instantiated with the argument 3 and
    the result of evaluating it at [0.1, 0.1] is returned.
    """
    rpa_model = _make_standard_model('rpa')(3)
    return rpa_model.evalDistribution([0.1, 0.1])
879
def test_empty_distribution():
    # type: () -> None
    """
    Make sure that a cylinder with radius -1.0 and zero background
    evaluates to exactly 0 at q=0.1.

    NOTE(review): presumably the negative radius leaves the model with no
    valid polydispersity points, which is the "empty distribution" named
    in the assertion message -- confirm against the kernel code.
    """
    model = _make_standard_model('cylinder')()
    for name, value in (('radius', -1.0), ('background', 0.)):
        model.setParam(name, value)
    q = np.asarray([0.1])
    Iq = model.evalDistribution(q)
    assert Iq[0] == 0., "empty distribution fails"
891
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models.

    Every name returned by *core.list_models* is passed to
    *_make_standard_model*.  If building fails, the exception is annotated
    with the name of the model being loaded and re-raised.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            # Only ordinary errors are annotated; KeyboardInterrupt and
            # SystemExit propagate untouched instead of being caught by a
            # bare except.
            annotate_exception("when loading "+name)
            raise
904
def test_old_name():
    # type: () -> None
    """
    Load and run the cylinder model under its old-style name,
    sas.models.CylinderModel.

    Does nothing when old-style plugin support is switched off or when
    the *sas* package cannot be imported.
    """
    if SUPPORT_OLD_STYLE_PLUGINS:
        try:
            # if sasview is not on the path then don't try to test it
            import sas
        except ImportError:
            pass
        else:
            load_standard_models()
            from sas.models.CylinderModel import CylinderModel
            CylinderModel().evalDistribution([0.1, 0.1])
920
def magnetic_demo():
    """
    Show an image of the sphere model with *sld_M0* set to 8.

    The model is evaluated with *calculate_Iq* on a 500x500 grid of
    (qx, qy) values spanning -0.35 to 0.35 on each axis, and the log of
    the result is displayed with pylab.
    """
    sphere = _make_standard_model('sphere')()
    sphere.setParam('sld_M0', 8)
    axis = np.linspace(-0.35, 0.35, 500)
    qx, qy = np.meshgrid(axis, axis)
    # calculate_Iq returns a pair; only the first item is plotted.
    intensity, _ = sphere.calculate_Iq(qx.flatten(), qy.flatten())
    image = intensity.reshape(qx.shape)

    # pylab is imported only once the calculation has completed.
    import pylab
    # presumably the 0.001 offset keeps the log finite at zero intensity
    pylab.imshow(np.log(image + 0.001))
    pylab.show()
933
if __name__ == "__main__":
    # Running the module as a script prints the cylinder test value; the
    # remaining checks are kept here, disabled, for manual use.
    cylinder_value = test_cylinder()
    print("cylinder(0.1,0.1)=%g"%cylinder_value)
    #magnetic_demo()
    #test_product()
    #test_structure_factor()
    #print("rpa:", test_rpa())
    #test_empty_distribution()
Note: See TracBrowser for help on using the repository browser.