source: sasmodels/sasmodels/sasview_model.py @ 30b60d2

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since 30b60d2 was 724257c, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

don't reload custom models in sasview unless the timestamp on the file has changed

  • Property mode set to 100644
File size: 28.2 KB
RevLine 
[87985ca]1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
[92d38285]7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
[87985ca]9"""
from __future__ import print_function

import math
from copy import deepcopy
import collections
import traceback
import logging
from os.path import basename, splitext, abspath, getmtime
try:
    # Python 3 renamed the low-level thread module to _thread.
    import _thread as thread
except ImportError:
    import thread

import numpy as np  # type: ignore

from . import core
from . import custom
from . import product
from . import generate
from . import weights
from . import modelinfo
from .details import make_kernel_args, dispersion_mesh
[ff7119b]29
# Names used only in type comments.  If the typing module (or one of the
# project modules imported alongside it) is unavailable, the names are
# simply left undefined.
try:
    from typing import (Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List,
                        Optional, Union, Callable)
    from .modelinfo import ModelInfo, Parameter
    from .kernel import KernelModel
    MultiplicityInfoType = NamedTuple(
        'MultiplicityInfo',
        [("number", int), ("control", str), ("choices", List[str]),
         ("x_axis_label", str)])
    SasviewModelType = Callable[[int], "SasviewModel"]
except ImportError:
    pass
41
logger = logging.getLogger(__name__)

#: Lock held by SasviewModel.calculate_Iq while a calculation is running.
calculation_lock = thread.allocate_lock()

#: True if pre-existing plugins, with the old names and parameters, should
#: continue to be supported.
SUPPORT_OLD_STYLE_PLUGINS = True

# TODO: separate x_axis_label from multiplicity info
#: (number, control, choices, x_axis_label) record describing the
#: multiplicity of a model.
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo', "number control choices x_axis_label")

#: set of defined models (standard and custom)
MODELS = dict()  # type: Dict[str, SasviewModelType]
#: custom model {path: model} mapping so we can check timestamps
MODEL_BY_PATH = dict()  # type: Dict[str, SasviewModelType]
60
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Return the model class registered under *modelname*.

    Names ending in ``.py`` are treated as paths and loaded as custom
    models; any other name is looked up in :data:`MODELS`.

    Raises *ValueError* if the name is not a path and is not registered.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
75
[56b2687]76
# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Build every model named by ``core.list_models()`` and register it in
    :data:`MODELS`.

    A model that fails to build is skipped after its traceback has been
    logged.  When :data:`SUPPORT_OLD_STYLE_PLUGINS` is set the old-style
    names are registered as well.

    Returns the list of all model classes currently in :data:`MODELS`.
    """
    for model_name in core.list_models():
        try:
            constructed = _make_standard_model(model_name)
        except Exception:
            logger.error(traceback.format_exc())
        else:
            MODELS[model_name] = constructed

    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()

    return list(MODELS.values())
[de97440]95
[4d76711]96
97def load_custom_model(path):
[60f03de]98    # type: (str) -> SasviewModelType
[4d76711]99    """
100    Load a custom model given the model path.
[ff7119b]101    """
[724257c]102    model = MODEL_BY_PATH.get(path, None)
103    if model is not None and model.timestamp == getmtime(path):
104        #logger.info("Model already loaded %s", path)
105        return model
106
107    #logger.info("Loading model %s", path)
[4d76711]108    kernel_module = custom.load_custom_kernel_module(path)
[724257c]109    if hasattr(kernel_module, 'Model'):
[92d38285]110        model = kernel_module.Model
[9457498]111        # Old style models do not set the name in the class attributes, so
112        # set it here; this name will be overridden when the object is created
113        # with an instance variable that has the same value.
114        if model.name == "":
115            model.name = splitext(basename(path))[0]
[20a70bc]116        if not hasattr(model, 'filename'):
[724257c]117            model.filename = abspath(kernel_module.__file__).replace('.pyc', '.py')
[e4bf271]118        if not hasattr(model, 'id'):
119            model.id = splitext(basename(model.filename))[0]
[724257c]120    else:
[56b2687]121        model_info = modelinfo.make_model_info(kernel_module)
[92d38285]122        model = _make_model_from_info(model_info)
[724257c]123    model.timestamp = getmtime(path)
[ed10b57]124
[2f2c70c]125    # If a model name already exists and we are loading a different model,
126    # use the model file name as the model name.
127    if model.name in MODELS and not model.filename == MODELS[model.name].filename:
128        _previous_name = model.name
129        model.name = model.id
[724257c]130
[2f2c70c]131        # If the new model name is still in the model list (for instance,
132        # if we put a cylinder.py in our plug-in directory), then append
133        # an identifier.
134        if model.name in MODELS and not model.filename == MODELS[model.name].filename:
135            model.name = model.id + '_user'
[724257c]136        logger.info("Model %s already exists: using %s [%s]",
137                    _previous_name, model.name, model.filename)
[ed10b57]138
[92d38285]139    MODELS[model.name] = model
[724257c]140    MODEL_BY_PATH[path] = model
[92d38285]141    return model
[4d76711]142
[87985ca]143
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build the sasview model class for the model called *name*.

    The kernel module is obtained from ``generate.load_kernel_module``,
    converted to model info, and wrapped by :func:`_make_model_from_info`.

    Returns a class that can be used directly as a sasview model.
    """
    info = modelinfo.make_model_info(generate.load_kernel_module(name))
    return _make_model_from_info(info)
[72a081d]156
157
def _register_old_models():
    # type: () -> None
    """
    Make the new models importable from sasview under their old names.

    ``sas.sascalc.fit`` is installed as ``sas.models`` (both as an attribute
    of *sas* and in ``sys.modules``) so that ``sas.models.pluginmodel`` is
    available to the plugin modules.  Then, for each entry of the (3, 1, 2)
    conversion table, a stand-in "module" named ``sas.models.<OldName>`` is
    placed in ``sys.modules``; it has a single attribute ``<OldName>``
    holding the corresponding new model class.
    """
    import sys
    import sas   # needed in order to set sas.models
    import sas.sascalc.fit
    # The monkey patch must be in place before "import sas.models" below.
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit

    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE
    table = CONVERSION_TABLE.get((3, 1, 2), {})
    for versioned_name, conversion in table.items():
        # CoreShellEllipsoidModel => core_shell_ellipsoid:1
        model_name = versioned_name.split(':')[0]
        if len(conversion) < 3:
            old_name = conversion[0]
        else:
            old_name = conversion[2]
        # A bare class stands in for the old module.
        ConstructedModule = type(old_name, (), {old_name: find_model(model_name)})
        old_path = 'sas.models.' + old_name
        # NOTE(review): the attribute is stored under the full dotted path
        # rather than under old_name -- confirm whether that is intended.
        setattr(sas.models, old_path, ConstructedModule)
        sys.modules[old_path] = ConstructedModule
183
184
def MultiplicationModel(form_factor, structure_factor):
    # type: ("SasviewModel", "SasviewModel") -> "SasviewModel"
    """
    Return a new model instance for the product of *form_factor* and
    *structure_factor*.

    The product model info is built from the ``_model_info`` of the two
    arguments; the resulting class is instantiated with no multiplicity.
    """
    p_info = form_factor._model_info
    s_info = structure_factor._model_info
    product_info = product.make_product_info(p_info, s_info)
    return _make_model_from_info(product_info)()
191
def _make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Create a :class:`SasviewModel` subclass for *model_info*.

    The class is named after the model and carries the attributes produced
    by :func:`_generate_model_attributes`, plus *filename*.
    """
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    namespace = _generate_model_attributes(model_info)
    namespace.update(__init__=__init__, filename=model_info.filename)
    return type(model_info.name, (SasviewModel,), namespace)
204
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Build the class attribute dictionary for the model *model_info*.

    The attributes carry everything needed to query the model details, so
    that a model does not have to be instantiated just to be inspected.

    All the attributes should be immutable to avoid accidents; the lists
    collected here are stored as tuples.
    """

    # TODO: allow model to override axis labels input/output name/unit

    kernel_pars = model_info.parameters.kernel_parameters
    non_fittable = []  # type: List[str]

    # Multiplicity: described by the kernel parameter named by
    # model_info.control, if there is one.
    xlabel = model_info.profile_axes[0] if model_info.profile is not None else ""
    control = next((p for p in kernel_pars if p.name == model_info.control),
                   None)
    if control is None:
        variants = MultiplicityInfo(0, "", [], xlabel)
    else:
        non_fittable.append(control.name)
        if control.choices:
            count = len(control.choices)
        else:
            count = int(control.limits[1])
        variants = MultiplicityInfo(count, control.name, control.choices,
                                    xlabel)

    # Only a single drop-down list parameter available: use the first
    # kernel parameter that has choices.
    chooser = next((p for p in kernel_pars if p.choices), None)
    fun_list = []
    if chooser is not None:
        fun_list = chooser.choices
        if chooser.length > 1:
            non_fittable.extend(chooser.id+str(k)
                                for k in range(1, chooser.length+1))

    # Organize parameter sets
    orientation_params = []
    magnetic_params = []
    fixed = []
    for p in model_info.parameters.user_parameters({}, is2d=True):
        if p.type not in ('orientation', 'magnetic'):
            continue
        width_name = p.name+".width"
        # Magnetic parameters are listed with the orientation parameters too.
        orientation_params.append(p.name)
        if p.type == 'orientation':
            orientation_params.append(width_name)
        else:
            magnetic_params.append(p.name)
        fixed.append(width_name)

    # Build class dictionary
    attrs = {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants[0] > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }  # type: Dict[str, Any]

    return attrs
[4d76711]273
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Concrete model classes are subclasses created with ``type()`` whose
    class attributes (``_model_info``, ``name``, ``id``, ...) are filled in
    from the model definition.  Instances hold the current parameter values
    in *params* and the polydispersity settings in *dispersion*.
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: Sequence[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: Sequence[str]
    #: names for which is_fittable() returns True
    fixed = None              # type: Sequence[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Initialise *params*, *dispersion* and *details* from the model info.

        :param multiplicity: multiplicity value for the model, or None.
            When given, the parameters hidden at that multiplicity and the
            multiplicity control parameter itself are left out of *params*.
        """

        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            hidden = self._model_info.get_hidden_parameters(multiplicity)
            hidden |= set([self.multiplicity_info.control])
        else:
            hidden = set()
        if self._model_info.structure_factor:
            hidden.add('scale')
            hidden.add('background')
            # Note: this changes the defaults table of the shared model info.
            self._model_info.parameters.defaults['background'] = 0.

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters({}, is2d=True):
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    # NOTE(review): the pickle/copy protocol hooks are spelled __getstate__
    # and __setstate__; with the names used here these two methods are only
    # run when called explicitly.  Confirm whether that is intended.
    def __get_state__(self):
        # type: () -> Dict[str, Any]
        """
        Return a copy of the instance dictionary without the kernel model.
        """
        state = self.__dict__.copy()
        # _model only appears in the instance dictionary once a calculation
        # has been run, so tolerate its absence.
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __set_state__(self, state):
        # type: (Dict[str, Any]) -> None
        """
        Restore the instance dictionary from *state*; the kernel model is
        reset so that it gets rebuilt on the next calculation.
        """
        self.__dict__ = state
        self._model = None

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        :param par_name: the parameter name to check
        :return: True if *par_name* is listed in the *fixed* attribute
        """
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()


    def getProfile(self):
        # type: () -> (np.ndarray, np.ndarray)
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        args = {} # type: Dict[str, Any]
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == self.multiplicity_info.control:
                value = float(self.multiplicity)
            elif p.length == 1:
                value = self.params.get(p.id, np.nan)
            else:
                # Vector parameters are stored as name1, name2, ...
                value = np.array([self.params.get(p.id+str(k), np.nan)
                                  for k in range(1, p.length+1)])
            args[p.id] = value

        x, y = self._model_info.profile(**args)
        return x, 1e-6*y

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        :param name: name of the parameter, or ``parameter.field`` for one
            of the dispersion settings of a parameter
        :param value: value of the parameter

        Raises *ValueError* if the model has no such parameter.
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            disperser = self.dispersion.get(toks[0])
            if disperser is not None and toks[1] in disperser:
                disperser[toks[1]] = value
                return
        elif name in self.params:
            # Look for standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        :param name: name of the parameter, or ``parameter.field`` for one
            of the dispersion settings of a parameter

        Raises *ValueError* if the model has no such parameter.
        """
        toks = name.split('.')
        if len(toks) == 2:
            # Look for dispersion parameters
            disperser = self.dispersion.get(toks[0])
            if disperser is not None and toks[1] in disperser:
                return disperser[toks[1]]
        elif name in self.params:
            # Look for standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model
        """
        # TODO: fix test so that parameter order doesn't matter
        ret = ['%s.%s' % (p_name, ext)
               for p_name in self.dispersion.keys()
               for ext in ('npts', 'nsigmas', 'width')]
        return ret

    def clone(self):
        # type: () -> "SasviewModel"
        """ Return a identical copy of self """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]


    def runXY(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        Raises *TypeError* if *qdist* is neither a list/tuple nor an array.
        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info.parameters.has_2d:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def get_composition_models(self):
        """
            Returns usable models that compose this model

            :return: (p_model, s_model) instances built from the second
                element of the model info *composition*, or (None, None)
                when the model info has no composition.
        """
        s_model = None
        p_model = None
        composition = getattr(self._model_info, "composition", None)
        if composition is not None:
            p_model = _make_model_from_info(composition[1][0])()
            s_model = _make_model_from_info(composition[1][1])()
        return p_model, s_model

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.

        The calculation runs while holding the module level
        *calculation_lock*.
        """
        ## uncomment the following when trying to debug the uncoordinated calls
        ## to calculate_Iq
        #if calculation_lock.locked():
        #    logger.info("calculation waiting for another thread to complete")
        #    logger.info("\n".join(traceback.format_stack()))

        with calculation_lock:
            return self._calculate_Iq(qx, qy)

    def _calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Build the kernel for the given q vectors, evaluate it with the
        current parameters and release the kernel and model afterwards.
        """
        if self._model is None:
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        calculator = self._model.make_kernel(q_vectors)
        try:
            parameters = self._model_info.parameters
            pairs = [self._get_weights(p) for p in parameters.call_parameters]
            #weights.plot_weights(self._model_info, pairs)
            call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
            #call_details.show()
            result = calculator(call_details, values, cutoff=self.cutoff,
                                magnetic=is_magnetic)
        finally:
            # Release the kernel and model even when the evaluation fails.
            calculator.release()
            self._model.release()
        return result

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius, or 1.0 if the model
            does not define an ER function
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            # Weighted average over the dispersion mesh.
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio, or 1.0 if the model does not
            define a VR function
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> Dict[str, Any]
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion

        Raises *ValueError* if *parameter* is not one of the model
        parameters.
        """
        if parameter in self.params:
            # TODO: Store the disperser object directly in the model.
            # The current method of relying on the sasview GUI to
            # remember them is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.
            # (2) Could write a disperser parameter retriever in sasview.
            # (3) Could modify sasview to use sasmodels.weights dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % parameter)

    def _dispersion_mesh(self):
        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.

        Only the parameters of type 'volume' are included.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
        """
        Return dispersion weights for parameter

        :return: (values, weights).  Parameters that are hidden or not
            polydisperse yield a single value with weight 1.0.
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return [self.multiplicity], [1.0]
            else:
                # For hidden parameters use the default value.
                value = self._model_info.parameters.defaults.get(par.name, np.nan)
                return [value], [1.0]
        elif par.polydisperse:
            dis = self.dispersion[par.name]
            if dis['type'] == 'array':
                value, weight = dis['values'], dis['weights']
            else:
                value, weight = weights.get_weights(
                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                    self.params[par.name], par.limits, par.relative_pd)
            # Normalise so that the weights sum to one.
            return value, weight / np.sum(weight)
        else:
            return [self.params[par.name]], [1.0]
[ce27e21]744
def test_cylinder():
    # type: () -> float
    """
    Build the cylinder model and return its value at [0.1,0.1], checking
    that the model runs.
    """
    model = _make_standard_model('cylinder')()
    return model.evalDistribution([0.1, 0.1])
[de97440]753
def test_structure_factor():
    # type: () -> None
    """
    Check that the hardsphere model evaluated at [0.1,0.1] runs and does
    not produce NaN.
    """
    hardsphere = _make_standard_model('hardsphere')()
    result = hardsphere.evalDistribution([0.1, 0.1])
    if np.isnan(result):
        raise ValueError("hardsphere returns null")
764
def test_rpa():
    # type: () -> float
    """
    Build the RPA model with multiplicity 3 and return its value at
    [0.1,0.1], checking that the model runs.
    """
    model = _make_standard_model('rpa')(3)
    return model.evalDistribution([0.1, 0.1])
[04045f4]773
def test_empty_distribution():
    # type: () -> None
    """
    Check that sasmodels returns NaN when there are no polydispersity
    points; the cylinder radius is set to -1 to provoke this.
    """
    model = _make_standard_model('cylinder')()
    for par_name, par_value in (('radius', -1.0), ('background', 0.)):
        model.setParam(par_name, par_value)
    Iq = model.evalDistribution(np.asarray([0.1]))
    assert np.isnan(Iq[0]), "empty distribution fails"
[4d76711]785
786def test_model_list():
[fa5fd8d]787    # type: () -> None
[4d76711]788    """
[749a7d4]789    Make sure that all models build as sasview models
[4d76711]790    """
791    from .exception import annotate_exception
792    for name in core.list_models():
793        try:
794            _make_standard_model(name)
795        except:
796            annotate_exception("when loading "+name)
797            raise
798
def test_old_name():
    # type: () -> None
    """
    Load the cylinder model through its old name,
    ``sas.models.CylinderModel``, and evaluate it at [0.1,0.1].

    Does nothing when old style plugins are not supported or when sasview
    cannot be imported.
    """
    if not SUPPORT_OLD_STYLE_PLUGINS:
        return
    # if sasview is not on the path then don't try to test it
    try:
        import sas
    except ImportError:
        return
    # Loading the standard models also registers the old names.
    load_standard_models()
    from sas.models.CylinderModel import CylinderModel
    old_cylinder = CylinderModel()
    old_cylinder.evalDistribution([0.1, 0.1])
814
if __name__ == "__main__":
    # Quick smoke test: evaluate the cylinder model and show the result.
    cylinder_value = test_cylinder()
    print("cylinder(0.1,0.1)=%g"%cylinder_value)
    #test_empty_distribution()
Note: See TracBrowser for help on using the repository browser.