source: sasmodels/sasmodels/sasview_model.py @ dd4f5ed

core_shell_microgelscostrafo411magnetic_modelticket-1257-vesicle-productticket_1156ticket_1265_superballticket_822_more_unit_tests
Last change on this file since dd4f5ed was 9644b5a, checked in by Paul Kienzle <pkienzle@…>, 7 years ago

avoid recalculation of components unless it is a P@S product model

  • Property mode set to 100644
File size: 29.7 KB
RevLine 
[87985ca]1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
[92d38285]7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
[87985ca]9"""
[4d76711]10from __future__ import print_function
[87985ca]11
[ce27e21]12import math
13from copy import deepcopy
[2622b3f]14import collections
[4d76711]15import traceback
16import logging
[724257c]17from os.path import basename, splitext, abspath, getmtime
[a38b065]18import thread
[ce27e21]19
[7ae2b7f]20import numpy as np  # type: ignore
[ce27e21]21
[aa4946b]22from . import core
[4d76711]23from . import custom
[a80e64c]24from . import product
[72a081d]25from . import generate
[fb5914f]26from . import weights
[6d6508e]27from . import modelinfo
[bde38b5]28from .details import make_kernel_args, dispersion_mesh
[ff7119b]29
[fa5fd8d]30try:
[60f03de]31    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable
[fa5fd8d]32    from .modelinfo import ModelInfo, Parameter
33    from .kernel import KernelModel
34    MultiplicityInfoType = NamedTuple(
35        'MuliplicityInfo',
36        [("number", int), ("control", str), ("choices", List[str]),
37         ("x_axis_label", str)])
[60f03de]38    SasviewModelType = Callable[[int], "SasviewModel"]
[fa5fd8d]39except ImportError:
40    pass
41
[724257c]42logger = logging.getLogger(__name__)
43
[a38b065]44calculation_lock = thread.allocate_lock()
45
[724257c]46#: True if pre-existing plugins, with the old names and parameters, should
47#: continue to be supported.
[c95dfc63]48SUPPORT_OLD_STYLE_PLUGINS = True
49
[fa5fd8d]50# TODO: separate x_axis_label from multiplicity info
51MultiplicityInfo = collections.namedtuple(
52    'MultiplicityInfo',
53    ["number", "control", "choices", "x_axis_label"],
54)
55
[724257c]56#: set of defined models (standard and custom)
57MODELS = {}  # type: Dict[str, SasviewModelType]
58#: custom model {path: model} mapping so we can check timestamps
59MODEL_BY_PATH = {}  # type: Dict[str, SasviewModelType]
60
[92d38285]61def find_model(modelname):
[b32dafd]62    # type: (str) -> SasviewModelType
63    """
64    Find a model by name.  If the model name ends in py, try loading it from
65    custom models, otherwise look for it in the list of builtin models.
66    """
[92d38285]67    # TODO: used by sum/product model to load an existing model
68    # TODO: doesn't handle custom models properly
69    if modelname.endswith('.py'):
70        return load_custom_model(modelname)
71    elif modelname in MODELS:
72        return MODELS[modelname]
73    else:
74        raise ValueError("unknown model %r"%modelname)
75
[56b2687]76
[fa5fd8d]77# TODO: figure out how to say that the return type is a subclass
[4d76711]78def load_standard_models():
[60f03de]79    # type: () -> List[SasviewModelType]
[4d76711]80    """
81    Load and return the list of predefined models.
82
83    If there is an error loading a model, then a traceback is logged and the
84    model is not returned.
85    """
86    for name in core.list_models():
87        try:
[92d38285]88            MODELS[name] = _make_standard_model(name)
[ee8f734]89        except Exception:
[724257c]90            logger.error(traceback.format_exc())
[c95dfc63]91    if SUPPORT_OLD_STYLE_PLUGINS:
92        _register_old_models()
93
[724257c]94    return list(MODELS.values())
[de97440]95
[4d76711]96
97def load_custom_model(path):
[60f03de]98    # type: (str) -> SasviewModelType
[4d76711]99    """
100    Load a custom model given the model path.
[ff7119b]101    """
[724257c]102    model = MODEL_BY_PATH.get(path, None)
103    if model is not None and model.timestamp == getmtime(path):
104        #logger.info("Model already loaded %s", path)
105        return model
106
107    #logger.info("Loading model %s", path)
[4d76711]108    kernel_module = custom.load_custom_kernel_module(path)
[724257c]109    if hasattr(kernel_module, 'Model'):
[92d38285]110        model = kernel_module.Model
[9457498]111        # Old style models do not set the name in the class attributes, so
112        # set it here; this name will be overridden when the object is created
113        # with an instance variable that has the same value.
114        if model.name == "":
115            model.name = splitext(basename(path))[0]
[20a70bc]116        if not hasattr(model, 'filename'):
[724257c]117            model.filename = abspath(kernel_module.__file__).replace('.pyc', '.py')
[e4bf271]118        if not hasattr(model, 'id'):
119            model.id = splitext(basename(model.filename))[0]
[724257c]120    else:
[56b2687]121        model_info = modelinfo.make_model_info(kernel_module)
[bcdd6c9]122        model = make_model_from_info(model_info)
[724257c]123    model.timestamp = getmtime(path)
[ed10b57]124
[2f2c70c]125    # If a model name already exists and we are loading a different model,
126    # use the model file name as the model name.
127    if model.name in MODELS and not model.filename == MODELS[model.name].filename:
128        _previous_name = model.name
129        model.name = model.id
[bf8c271]130
[2f2c70c]131        # If the new model name is still in the model list (for instance,
132        # if we put a cylinder.py in our plug-in directory), then append
133        # an identifier.
134        if model.name in MODELS and not model.filename == MODELS[model.name].filename:
135            model.name = model.id + '_user'
[724257c]136        logger.info("Model %s already exists: using %s [%s]",
137                    _previous_name, model.name, model.filename)
[ed10b57]138
[92d38285]139    MODELS[model.name] = model
[724257c]140    MODEL_BY_PATH[path] = model
[92d38285]141    return model
[4d76711]142
[87985ca]143
[bcdd6c9]144def make_model_from_info(model_info):
145    # type: (ModelInfo) -> SasviewModelType
146    """
147    Convert *model_info* into a SasView model wrapper.
148    """
149    def __init__(self, multiplicity=None):
150        SasviewModel.__init__(self, multiplicity=multiplicity)
151    attrs = _generate_model_attributes(model_info)
152    attrs['__init__'] = __init__
153    attrs['filename'] = model_info.filename
154    ConstructedModel = type(model_info.name, (SasviewModel,), attrs) # type: SasviewModelType
155    return ConstructedModel
156
157
[4d76711]158def _make_standard_model(name):
[60f03de]159    # type: (str) -> SasviewModelType
[ff7119b]160    """
[4d76711]161    Load the sasview model defined by *name*.
[72a081d]162
[4d76711]163    *name* can be a standard model name or a path to a custom model.
[87985ca]164
[4d76711]165    Returns a class that can be used directly as a sasview model.
[ff7119b]166    """
[4d76711]167    kernel_module = generate.load_kernel_module(name)
[fa5fd8d]168    model_info = modelinfo.make_model_info(kernel_module)
[bcdd6c9]169    return make_model_from_info(model_info)
[72a081d]170
171
[724257c]172def _register_old_models():
173    # type: () -> None
174    """
175    Place the new models into sasview under the old names.
176
177    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
178    is available to the plugin modules.
179    """
180    import sys
181    import sas   # needed in order to set sas.models
182    import sas.sascalc.fit
183    sys.modules['sas.models'] = sas.sascalc.fit
184    sas.models = sas.sascalc.fit
185
186    import sas.models
187    from sasmodels.conversion_table import CONVERSION_TABLE
188    for new_name, conversion in CONVERSION_TABLE.get((3, 1, 2), {}).items():
189        # CoreShellEllipsoidModel => core_shell_ellipsoid:1
190        new_name = new_name.split(':')[0]
191        old_name = conversion[0] if len(conversion) < 3 else conversion[2]
192        module_attrs = {old_name: find_model(new_name)}
193        ConstructedModule = type(old_name, (), module_attrs)
194        old_path = 'sas.models.' + old_name
195        setattr(sas.models, old_path, ConstructedModule)
196        sys.modules[old_path] = ConstructedModule
197
198
[a80e64c]199def MultiplicationModel(form_factor, structure_factor):
200    # type: ("SasviewModel", "SasviewModel") -> "SasviewModel"
201    model_info = product.make_product_info(form_factor._model_info,
202                                           structure_factor._model_info)
[bcdd6c9]203    ConstructedModel = make_model_from_info(model_info)
[a80e64c]204    return ConstructedModel()
205
[ce27e21]206
[fa5fd8d]207def _generate_model_attributes(model_info):
208    # type: (ModelInfo) -> Dict[str, Any]
209    """
210    Generate the class attributes for the model.
211
212    This should include all the information necessary to query the model
213    details so that you do not need to instantiate a model to query it.
214
215    All the attributes should be immutable to avoid accidents.
216    """
217
218    # TODO: allow model to override axis labels input/output name/unit
219
[a18c5b3]220    # Process multiplicity
[fa5fd8d]221    non_fittable = []  # type: List[str]
[04045f4]222    xlabel = model_info.profile_axes[0] if model_info.profile is not None else ""
223    variants = MultiplicityInfo(0, "", [], xlabel)
[a18c5b3]224    for p in model_info.parameters.kernel_parameters:
[04045f4]225        if p.name == model_info.control:
[fa5fd8d]226            non_fittable.append(p.name)
[04045f4]227            variants = MultiplicityInfo(
[ce176ca]228                len(p.choices) if p.choices else int(p.limits[1]),
229                p.name, p.choices, xlabel
[fa5fd8d]230            )
231            break
232
[50ec515]233    # Only a single drop-down list parameter available
234    fun_list = []
235    for p in model_info.parameters.kernel_parameters:
236        if p.choices:
237            fun_list = p.choices
238            if p.length > 1:
239                non_fittable.extend(p.id+str(k) for k in range(1, p.length+1))
240            break
241
[a18c5b3]242    # Organize parameter sets
[fa5fd8d]243    orientation_params = []
244    magnetic_params = []
245    fixed = []
[85fe7f8]246    for p in model_info.parameters.user_parameters({}, is2d=True):
[fa5fd8d]247        if p.type == 'orientation':
248            orientation_params.append(p.name)
249            orientation_params.append(p.name+".width")
250            fixed.append(p.name+".width")
[32e3c9b]251        elif p.type == 'magnetic':
[fa5fd8d]252            orientation_params.append(p.name)
253            magnetic_params.append(p.name)
254            fixed.append(p.name+".width")
[a18c5b3]255
[32e3c9b]256
[a18c5b3]257    # Build class dictionary
258    attrs = {}  # type: Dict[str, Any]
259    attrs['_model_info'] = model_info
260    attrs['name'] = model_info.name
261    attrs['id'] = model_info.id
262    attrs['description'] = model_info.description
263    attrs['category'] = model_info.category
264    attrs['is_structure_factor'] = model_info.structure_factor
265    attrs['is_form_factor'] = model_info.ER is not None
266    attrs['is_multiplicity_model'] = variants[0] > 1
267    attrs['multiplicity_info'] = variants
[fa5fd8d]268    attrs['orientation_params'] = tuple(orientation_params)
269    attrs['magnetic_params'] = tuple(magnetic_params)
270    attrs['fixed'] = tuple(fixed)
271    attrs['non_fittable'] = tuple(non_fittable)
[50ec515]272    attrs['fun_list'] = tuple(fun_list)
[fa5fd8d]273
274    return attrs
[4d76711]275
[ce27e21]276class SasviewModel(object):
277    """
278    Sasview wrapper for opencl/ctypes model.
279    """
[fa5fd8d]280    # Model parameters for the specific model are set in the class constructor
281    # via the _generate_model_attributes function, which subclasses
282    # SasviewModel.  They are included here for typing and documentation
283    # purposes.
284    _model = None       # type: KernelModel
285    _model_info = None  # type: ModelInfo
286    #: load/save name for the model
287    id = None           # type: str
288    #: display name for the model
289    name = None         # type: str
290    #: short model description
291    description = None  # type: str
292    #: default model category
293    category = None     # type: str
294
295    #: names of the orientation parameters in the order they appear
[724257c]296    orientation_params = None # type: List[str]
[fa5fd8d]297    #: names of the magnetic parameters in the order they appear
[724257c]298    magnetic_params = None    # type: List[str]
[fa5fd8d]299    #: names of the fittable parameters
[724257c]300    fixed = None              # type: List[str]
[fa5fd8d]301    # TODO: the attribute fixed is ill-named
302
303    # Axis labels
304    input_name = "Q"
305    input_unit = "A^{-1}"
306    output_name = "Intensity"
307    output_unit = "cm^{-1}"
308
309    #: default cutoff for polydispersity
310    cutoff = 1e-5
311
312    # Note: Use non-mutable values for class attributes to avoid errors
313    #: parameters that are not fitted
314    non_fittable = ()        # type: Sequence[str]
315
316    #: True if model should appear as a structure factor
317    is_structure_factor = False
318    #: True if model should appear as a form factor
319    is_form_factor = False
320    #: True if model has multiplicity
321    is_multiplicity_model = False
322    #: Mulitplicity information
323    multiplicity_info = None # type: MultiplicityInfoType
324
325    # Per-instance variables
326    #: parameter {name: value} mapping
327    params = None      # type: Dict[str, float]
328    #: values for dispersion width, npts, nsigmas and type
329    dispersion = None  # type: Dict[str, Any]
330    #: units and limits for each parameter
[60f03de]331    details = None     # type: Dict[str, Sequence[Any]]
332    #                  # actual type is Dict[str, List[str, float, float]]
[04dc697]333    #: multiplicity value, or None if no multiplicity on the model
[fa5fd8d]334    multiplicity = None     # type: Optional[int]
[04dc697]335    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
336    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]
[fa5fd8d]337
338    def __init__(self, multiplicity=None):
[04dc697]339        # type: (Optional[int]) -> None
[2622b3f]340
[04045f4]341        # TODO: _persistency_dict to persistency_dict throughout sasview
342        # TODO: refactor multiplicity to encompass variants
343        # TODO: dispersion should be a class
[fa5fd8d]344        # TODO: refactor multiplicity info
345        # TODO: separate profile view from multiplicity
346        # The button label, x and y axis labels and scale need to be under
347        # the control of the model, not the fit page.  Maximum flexibility,
348        # the fit page would supply the canvas and the profile could plot
349        # how it wants, but this assumes matplotlib.  Next level is that
350        # we provide some sort of data description including title, labels
351        # and lines to plot.
352
[04045f4]353        # Get the list of hidden parameters given the mulitplicity
354        # Don't include multiplicity in the list of parameters
[fa5fd8d]355        self.multiplicity = multiplicity
[04045f4]356        if multiplicity is not None:
357            hidden = self._model_info.get_hidden_parameters(multiplicity)
358            hidden |= set([self.multiplicity_info.control])
359        else:
360            hidden = set()
[8f93522]361        if self._model_info.structure_factor:
362            hidden.add('scale')
363            hidden.add('background')
364            self._model_info.parameters.defaults['background'] = 0.
[04045f4]365
[04dc697]366        self._persistency_dict = {}
[fa5fd8d]367        self.params = collections.OrderedDict()
[b3a85cd]368        self.dispersion = collections.OrderedDict()
[fa5fd8d]369        self.details = {}
[8977226]370        for p in self._model_info.parameters.user_parameters({}, is2d=True):
[04045f4]371            if p.name in hidden:
[fa5fd8d]372                continue
[fcd7bbd]373            self.params[p.name] = p.default
[fa5fd8d]374            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
[fb5914f]375            if p.polydisperse:
[fa5fd8d]376                self.details[p.id+".width"] = [
377                    "", 0.0, 1.0 if p.relative_pd else np.inf
378                ]
[fb5914f]379                self.dispersion[p.name] = {
380                    'width': 0,
381                    'npts': 35,
382                    'nsigmas': 3,
383                    'type': 'gaussian',
384                }
[ce27e21]385
[de97440]386    def __get_state__(self):
[fa5fd8d]387        # type: () -> Dict[str, Any]
[de97440]388        state = self.__dict__.copy()
[4d76711]389        state.pop('_model')
[de97440]390        # May need to reload model info on set state since it has pointers
391        # to python implementations of Iq, etc.
392        #state.pop('_model_info')
393        return state
394
395    def __set_state__(self, state):
[fa5fd8d]396        # type: (Dict[str, Any]) -> None
[de97440]397        self.__dict__ = state
[fb5914f]398        self._model = None
[de97440]399
[ce27e21]400    def __str__(self):
[fa5fd8d]401        # type: () -> str
[ce27e21]402        """
403        :return: string representation
404        """
405        return self.name
406
407    def is_fittable(self, par_name):
[fa5fd8d]408        # type: (str) -> bool
[ce27e21]409        """
410        Check if a given parameter is fittable or not
411
412        :param par_name: the parameter name to check
413        """
[e758662]414        return par_name in self.fixed
[ce27e21]415        #For the future
416        #return self.params[str(par_name)].is_fittable()
417
418
419    def getProfile(self):
[fa5fd8d]420        # type: () -> (np.ndarray, np.ndarray)
[ce27e21]421        """
422        Get SLD profile
423
424        : return: (z, beta) where z is a list of depth of the transition points
425                beta is a list of the corresponding SLD values
426        """
[745b7bb]427        args = {} # type: Dict[str, Any]
[fa5fd8d]428        for p in self._model_info.parameters.kernel_parameters:
429            if p.id == self.multiplicity_info.control:
[745b7bb]430                value = float(self.multiplicity)
[fa5fd8d]431            elif p.length == 1:
[745b7bb]432                value = self.params.get(p.id, np.NaN)
[fa5fd8d]433            else:
[745b7bb]434                value = np.array([self.params.get(p.id+str(k), np.NaN)
[b32dafd]435                                  for k in range(1, p.length+1)])
[745b7bb]436            args[p.id] = value
437
[e7fe459]438        x, y = self._model_info.profile(**args)
439        return x, 1e-6*y
[ce27e21]440
441    def setParam(self, name, value):
[fa5fd8d]442        # type: (str, float) -> None
[ce27e21]443        """
444        Set the value of a model parameter
445
446        :param name: name of the parameter
447        :param value: value of the parameter
448
449        """
450        # Look for dispersion parameters
451        toks = name.split('.')
[de0c4ba]452        if len(toks) == 2:
[ce27e21]453            for item in self.dispersion.keys():
[e758662]454                if item == toks[0]:
[ce27e21]455                    for par in self.dispersion[item]:
[e758662]456                        if par == toks[1]:
[ce27e21]457                            self.dispersion[item][par] = value
458                            return
459        else:
460            # Look for standard parameter
461            for item in self.params.keys():
[e758662]462                if item == name:
[ce27e21]463                    self.params[item] = value
464                    return
465
[63b32bb]466        raise ValueError("Model does not contain parameter %s" % name)
[ce27e21]467
468    def getParam(self, name):
[fa5fd8d]469        # type: (str) -> float
[ce27e21]470        """
471        Set the value of a model parameter
472
473        :param name: name of the parameter
474
475        """
476        # Look for dispersion parameters
477        toks = name.split('.')
[de0c4ba]478        if len(toks) == 2:
[ce27e21]479            for item in self.dispersion.keys():
[e758662]480                if item == toks[0]:
[ce27e21]481                    for par in self.dispersion[item]:
[e758662]482                        if par == toks[1]:
[ce27e21]483                            return self.dispersion[item][par]
484        else:
485            # Look for standard parameter
486            for item in self.params.keys():
[e758662]487                if item == name:
[ce27e21]488                    return self.params[item]
489
[63b32bb]490        raise ValueError("Model does not contain parameter %s" % name)
[ce27e21]491
492    def getParamList(self):
[04dc697]493        # type: () -> Sequence[str]
[ce27e21]494        """
495        Return a list of all available parameters for the model
496        """
[04dc697]497        param_list = list(self.params.keys())
[ce27e21]498        # WARNING: Extending the list with the dispersion parameters
[de0c4ba]499        param_list.extend(self.getDispParamList())
500        return param_list
[ce27e21]501
502    def getDispParamList(self):
[04dc697]503        # type: () -> Sequence[str]
[ce27e21]504        """
[fb5914f]505        Return a list of polydispersity parameters for the model
[ce27e21]506        """
[1780d59]507        # TODO: fix test so that parameter order doesn't matter
[3bcb88c]508        ret = ['%s.%s' % (p_name, ext)
509               for p_name in self.dispersion.keys()
510               for ext in ('npts', 'nsigmas', 'width')]
[9404dd3]511        #print(ret)
[1780d59]512        return ret
[ce27e21]513
514    def clone(self):
[04dc697]515        # type: () -> "SasviewModel"
[ce27e21]516        """ Return a identical copy of self """
517        return deepcopy(self)
518
519    def run(self, x=0.0):
[fa5fd8d]520        # type: (Union[float, (float, float), List[float]]) -> float
[ce27e21]521        """
522        Evaluate the model
523
524        :param x: input q, or [q,phi]
525
526        :return: scattering function P(q)
527
528        **DEPRECATED**: use calculate_Iq instead
529        """
[de0c4ba]530        if isinstance(x, (list, tuple)):
[3c56da87]531            # pylint: disable=unpacking-non-sequence
[ce27e21]532            q, phi = x
[60f03de]533            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
[ce27e21]534        else:
[60f03de]535            return self.calculate_Iq([x])[0]
[ce27e21]536
537
538    def runXY(self, x=0.0):
[fa5fd8d]539        # type: (Union[float, (float, float), List[float]]) -> float
[ce27e21]540        """
541        Evaluate the model in cartesian coordinates
542
543        :param x: input q, or [qx, qy]
544
545        :return: scattering function P(q)
546
547        **DEPRECATED**: use calculate_Iq instead
548        """
[de0c4ba]549        if isinstance(x, (list, tuple)):
[60f03de]550            return self.calculate_Iq([x[0]], [x[1]])[0]
[ce27e21]551        else:
[60f03de]552            return self.calculate_Iq([x])[0]
[ce27e21]553
554    def evalDistribution(self, qdist):
[04dc697]555        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
[d138d43]556        r"""
[ce27e21]557        Evaluate a distribution of q-values.
558
[d138d43]559        :param qdist: array of q or a list of arrays [qx,qy]
[ce27e21]560
[d138d43]561        * For 1D, a numpy array is expected as input
[ce27e21]562
[d138d43]563        ::
[ce27e21]564
[d138d43]565            evalDistribution(q)
[ce27e21]566
[d138d43]567          where *q* is a numpy array.
[ce27e21]568
[d138d43]569        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input
[ce27e21]570
[d138d43]571        ::
[ce27e21]572
[d138d43]573              qx = [ qx[0], qx[1], qx[2], ....]
574              qy = [ qy[0], qy[1], qy[2], ....]
[ce27e21]575
[d138d43]576        If the model is 1D only, then
[ce27e21]577
[d138d43]578        .. math::
[ce27e21]579
[d138d43]580            q = \sqrt{q_x^2+q_y^2}
[ce27e21]581
582        """
[de0c4ba]583        if isinstance(qdist, (list, tuple)):
[ce27e21]584            # Check whether we have a list of ndarrays [qx,qy]
585            qx, qy = qdist
[6d6508e]586            if not self._model_info.parameters.has_2d:
[de0c4ba]587                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
[5d4777d]588            else:
589                return self.calculate_Iq(qx, qy)
[ce27e21]590
591        elif isinstance(qdist, np.ndarray):
592            # We have a simple 1D distribution of q-values
593            return self.calculate_Iq(qdist)
594
595        else:
[3c56da87]596            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
597                            % type(qdist))
[ce27e21]598
[9dcb21d]599    def calc_composition_models(self, qx):
[64614ad]600        """
[9dcb21d]601        returns parts of the composition model or None if not a composition
602        model.
[64614ad]603        """
[946c8d27]604        # TODO: have calculate_Iq return the intermediates.
605        #
606        # The current interface causes calculate_Iq() to be called twice,
607        # once to get the combined result and again to get the intermediate
608        # results.  This is necessary for now.
609        # Long term, the solution is to change the interface to calculate_Iq
610        # so that it returns a results object containing all the bits:
[9644b5a]611        #     the A, B, C, ... of the composition model (and any subcomponents?)
[946c8d27]612        #     the P and S of the product model,
613        #     the combined model before resolution smearing,
614        #     the sasmodel before sesans conversion,
615        #     the oriented 2D model used to fit oriented usans data,
616        #     the final I(q),
617        #     ...
[9644b5a]618        #
[946c8d27]619        # Have the model calculator add all of these blindly to the data
620        # tree, and update the graphs which contain them.  The fitter
621        # needs to be updated to use the I(q) value only, ignoring the rest.
622        #
623        # The simple fix of returning the existing intermediate results
624        # will not work for a couple of reasons: (1) another thread may
625        # sneak in to compute its own results before calc_composition_models
626        # is called, and (2) calculate_Iq is currently called three times:
627        # once with q, once with q values before qmin and once with q values
628        # after q max.  Both of these should be addressed before
629        # replacing this code.
[9644b5a]630        composition = self._model_info.composition
631        if composition and composition[0] == 'product': # only P*S for now
632            with calculation_lock:
633                self._calculate_Iq(qx)
634                return self._intermediate_results
635        else:
636            return None
[bf8c271]637
[fa5fd8d]638    def calculate_Iq(self, qx, qy=None):
639        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
[ff7119b]640        """
641        Calculate Iq for one set of q with the current parameters.
642
643        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.
644
645        This should NOT be used for fitting since it copies the *q* vectors
646        to the card for each evaluation.
647        """
[a38b065]648        ## uncomment the following when trying to debug the uncoordinated calls
649        ## to calculate_Iq
650        #if calculation_lock.locked():
[724257c]651        #    logger.info("calculation waiting for another thread to complete")
652        #    logger.info("\n".join(traceback.format_stack()))
[a38b065]653
654        with calculation_lock:
655            return self._calculate_Iq(qx, qy)
656
657    def _calculate_Iq(self, qx, qy=None):
[6a0d6aa]658        #core.HAVE_OPENCL = False
[fb5914f]659        if self._model is None:
[d2bb604]660            self._model = core.build_model(self._model_info)
[fa5fd8d]661        if qy is not None:
662            q_vectors = [np.asarray(qx), np.asarray(qy)]
663        else:
664            q_vectors = [np.asarray(qx)]
[a738209]665        calculator = self._model.make_kernel(q_vectors)
[6a0d6aa]666        parameters = self._model_info.parameters
667        pairs = [self._get_weights(p) for p in parameters.call_parameters]
[9c1a59c]668        #weights.plot_weights(self._model_info, pairs)
[bde38b5]669        call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
[4edec6f]670        #call_details.show()
671        #print("pairs", pairs)
672        #print("params", self.params)
673        #print("values", values)
674        #print("is_mag", is_magnetic)
[6a0d6aa]675        result = calculator(call_details, values, cutoff=self.cutoff,
[9eb3632]676                            magnetic=is_magnetic)
[bf8c271]677        self._intermediate_results = getattr(calculator, 'results', None)
[a738209]678        calculator.release()
[9f37726]679        self._model.release()
[ce27e21]680        return result
681
682    def calculate_ER(self):
[fa5fd8d]683        # type: () -> float
[ce27e21]684        """
685        Calculate the effective radius for P(q)*S(q)
686
687        :return: the value of the effective radius
688        """
[4bfd277]689        if self._model_info.ER is None:
[ce27e21]690            return 1.0
691        else:
[4bfd277]692            value, weight = self._dispersion_mesh()
693            fv = self._model_info.ER(*value)
[9404dd3]694            #print(values[0].shape, weights.shape, fv.shape)
[4bfd277]695            return np.sum(weight * fv) / np.sum(weight)
[ce27e21]696
697    def calculate_VR(self):
[fa5fd8d]698        # type: () -> float
[ce27e21]699        """
700        Calculate the volf ratio for P(q)*S(q)
701
702        :return: the value of the volf ratio
703        """
[4bfd277]704        if self._model_info.VR is None:
[ce27e21]705            return 1.0
706        else:
[4bfd277]707            value, weight = self._dispersion_mesh()
708            whole, part = self._model_info.VR(*value)
709            return np.sum(weight * part) / np.sum(weight * whole)
[ce27e21]710
711    def set_dispersion(self, parameter, dispersion):
[fa5fd8d]712        # type: (str, weights.Dispersion) -> Dict[str, Any]
[ce27e21]713        """
714        Set the dispersion object for a model parameter
715
716        :param parameter: name of the parameter [string]
717        :param dispersion: dispersion object of type Dispersion
718        """
[fa800e72]719        if parameter in self.params:
[1780d59]720            # TODO: Store the disperser object directly in the model.
[56b2687]721            # The current method of relying on the sasview GUI to
[fa800e72]722            # remember them is kind of funky.
[1780d59]723            # Note: can't seem to get disperser parameters from sasview
[9c1a59c]724            # (1) Could create a sasview model that has not yet been
[1780d59]725            # converted, assign the disperser to one of its polydisperse
726            # parameters, then retrieve the disperser parameters from the
[9c1a59c]727            # sasview model.
728            # (2) Could write a disperser parameter retriever in sasview.
729            # (3) Could modify sasview to use sasmodels.weights dispersers.
[1780d59]730            # For now, rely on the fact that the sasview only ever uses
731            # new dispersers in the set_dispersion call and create a new
732            # one instead of trying to assign parameters.
[ce27e21]733            self.dispersion[parameter] = dispersion.get_pars()
734        else:
735            raise ValueError("%r is not a dispersity or orientation parameter")
736
[aa4946b]737    def _dispersion_mesh(self):
[fa5fd8d]738        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
[ce27e21]739        """
740        Create a mesh grid of dispersion parameters and weights.
741
742        Returns [p1,p2,...],w where pj is a vector of values for parameter j
743        and w is a vector containing the products for weights for each
744        parameter set in the vector.
745        """
[4bfd277]746        pars = [self._get_weights(p)
747                for p in self._model_info.parameters.call_parameters
748                if p.type == 'volume']
[9eb3632]749        return dispersion_mesh(self._model_info, pars)
[ce27e21]750
751    def _get_weights(self, par):
[fa5fd8d]752        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
[de0c4ba]753        """
[fb5914f]754        Return dispersion weights for parameter
[de0c4ba]755        """
[fa5fd8d]756        if par.name not in self.params:
757            if par.name == self.multiplicity_info.control:
[4edec6f]758                return [self.multiplicity], [1.0]
[fa5fd8d]759            else:
[8f93522]760                # For hidden parameters use the default value.
761                value = self._model_info.parameters.defaults.get(par.name, np.NaN)
762                return [value], [1.0]
[fa5fd8d]763        elif par.polydisperse:
[fb5914f]764            dis = self.dispersion[par.name]
[9c1a59c]765            if dis['type'] == 'array':
766                value, weight = dis['values'], dis['weights']
767            else:
768                value, weight = weights.get_weights(
769                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
770                    self.params[par.name], par.limits, par.relative_pd)
[fb5914f]771            return value, weight / np.sum(weight)
772        else:
[4edec6f]773            return [self.params[par.name]], [1.0]
[ce27e21]774
[749a7d4]775def test_cylinder():
[fa5fd8d]776    # type: () -> float
[4d76711]777    """
[749a7d4]778    Test that the cylinder model runs, returning the value at [0.1,0.1].
[4d76711]779    """
780    Cylinder = _make_standard_model('cylinder')
[fb5914f]781    cylinder = Cylinder()
[b32dafd]782    return cylinder.evalDistribution([0.1, 0.1])
[de97440]783
[8f93522]784def test_structure_factor():
785    # type: () -> float
786    """
[749a7d4]787    Test that 2-D hardsphere model runs and doesn't produce NaN.
[8f93522]788    """
789    Model = _make_standard_model('hardsphere')
790    model = Model()
791    value = model.evalDistribution([0.1, 0.1])
792    if np.isnan(value):
793        raise ValueError("hardsphere returns null")
794
[04045f4]795def test_rpa():
796    # type: () -> float
797    """
[749a7d4]798    Test that the 2-D RPA model runs
[04045f4]799    """
800    RPA = _make_standard_model('rpa')
801    rpa = RPA(3)
[b32dafd]802    return rpa.evalDistribution([0.1, 0.1])
[04045f4]803
[749a7d4]804def test_empty_distribution():
805    # type: () -> None
806    """
807    Make sure that sasmodels returns NaN when there are no polydispersity points
808    """
809    Cylinder = _make_standard_model('cylinder')
810    cylinder = Cylinder()
811    cylinder.setParam('radius', -1.0)
812    cylinder.setParam('background', 0.)
813    Iq = cylinder.evalDistribution(np.asarray([0.1]))
814    assert np.isnan(Iq[0]), "empty distribution fails"
[4d76711]815
816def test_model_list():
[fa5fd8d]817    # type: () -> None
[4d76711]818    """
[749a7d4]819    Make sure that all models build as sasview models
[4d76711]820    """
821    from .exception import annotate_exception
822    for name in core.list_models():
823        try:
824            _make_standard_model(name)
825        except:
826            annotate_exception("when loading "+name)
827            raise
828
[c95dfc63]829def test_old_name():
830    # type: () -> None
831    """
832    Load and run cylinder model from sas.models.CylinderModel
833    """
834    if not SUPPORT_OLD_STYLE_PLUGINS:
835        return
836    try:
837        # if sasview is not on the path then don't try to test it
838        import sas
839    except ImportError:
840        return
841    load_standard_models()
842    from sas.models.CylinderModel import CylinderModel
843    CylinderModel().evalDistribution([0.1, 0.1])
844
[fb5914f]845if __name__ == "__main__":
[749a7d4]846    print("cylinder(0.1,0.1)=%g"%test_cylinder())
847    #test_empty_distribution()
Note: See TracBrowser for help on using the repository browser.