source: sasmodels/sasmodels/sasview_model.py @ edb0f85

core_shell_microgels, costrafo411, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since edb0f85 was edb0f85, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

tweak implementation of sld profile for product models

  • Property mode set to 100644
File size: 29.8 KB
Rev Line
[87985ca]1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
[92d38285]7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
[87985ca]9"""
[4d76711]10from __future__ import print_function
[87985ca]11
[ce27e21]12import math
13from copy import deepcopy
[2622b3f]14import collections
[4d76711]15import traceback
16import logging
[724257c]17from os.path import basename, splitext, abspath, getmtime
[9f8ade1]18try:
19    import _thread as thread
20except ImportError:
21    import thread
[ce27e21]22
[7ae2b7f]23import numpy as np  # type: ignore
[ce27e21]24
[aa4946b]25from . import core
[4d76711]26from . import custom
[a80e64c]27from . import product
[72a081d]28from . import generate
[fb5914f]29from . import weights
[6d6508e]30from . import modelinfo
[bde38b5]31from .details import make_kernel_args, dispersion_mesh
[ff7119b]32
[fa5fd8d]33try:
[60f03de]34    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable
[fa5fd8d]35    from .modelinfo import ModelInfo, Parameter
36    from .kernel import KernelModel
37    MultiplicityInfoType = NamedTuple(
38        'MuliplicityInfo',
39        [("number", int), ("control", str), ("choices", List[str]),
40         ("x_axis_label", str)])
[60f03de]41    SasviewModelType = Callable[[int], "SasviewModel"]
[fa5fd8d]42except ImportError:
43    pass
44
logger = logging.getLogger(__name__)

#: Process-wide lock held around every call to SasviewModel._calculate_Iq
#: (see calculate_Iq and calc_composition_models).
calculation_lock = thread.allocate_lock()

#: True if pre-existing plugins, with the old names and parameters, should
#: continue to be supported.
SUPPORT_OLD_STYLE_PLUGINS = True

# TODO: separate x_axis_label from multiplicity info
#: Multiplicity description: number of variants, name of the control
#: parameter, list of variant choices, and the profile x-axis label.
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo', 'number control choices x_axis_label')

#: set of defined models (standard and custom)
MODELS = dict()  # type: Dict[str, SasviewModelType]
#: custom model {path: model} mapping so we can check timestamps
MODEL_BY_PATH = dict()  # type: Dict[str, SasviewModelType]
63
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Look up a model class by name.

    Names ending in ``.py`` are treated as paths to custom models and loaded
    with load_custom_model(); any other name must already be registered in
    MODELS.

    :raises ValueError: if a non-``.py`` name is not registered.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
78
[56b2687]79
# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Load and return the list of predefined models.

    Each model listed by core.list_models() is built and registered in
    MODELS.  A model that fails to build is skipped and its traceback is
    logged at error level.  When SUPPORT_OLD_STYLE_PLUGINS is true, the
    models are also registered under their old (3.1.2) names.

    :return: every model class currently in MODELS
    """
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
        except Exception:
            logger.error(traceback.format_exc())
        else:
            MODELS[model_name] = model_class

    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()

    return list(MODELS.values())
[de97440]98
[4d76711]99
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load a custom model given the model path.

    Loaded models are cached by *path* in MODEL_BY_PATH together with the
    file modification time; the cached class is returned unchanged while
    the file's mtime still matches.  Otherwise the file is (re)loaded and
    the resulting class is registered in MODELS, renamed to its id (or
    ``<id>_user``) if its name collides with a model from a different file.
    """
    # Read the timestamp *before* loading.  If the file changes while it is
    # being loaded, the recorded time is older than the file, so the next
    # call reloads it instead of caching stale content forever.
    timestamp = getmtime(path)
    model = MODEL_BY_PATH.get(path, None)
    if model is not None and model.timestamp == timestamp:
        return model

    kernel_module = custom.load_custom_kernel_module(path)
    if hasattr(kernel_module, 'Model'):
        model = kernel_module.Model
        # Old style models do not set the name in the class attributes, so
        # set it here; this name will be overridden when the object is created
        # with an instance variable that has the same value.
        if model.name == "":
            model.name = splitext(basename(path))[0]
        if not hasattr(model, 'filename'):
            filename = abspath(kernel_module.__file__)
            # Point at the source rather than the bytecode; only the
            # extension is changed, not a '.pyc' elsewhere in the path.
            if filename.endswith('.pyc'):
                filename = filename[:-1]
            model.filename = filename
        if not hasattr(model, 'id'):
            model.id = splitext(basename(model.filename))[0]
    else:
        model_info = modelinfo.make_model_info(kernel_module)
        model = make_model_from_info(model_info)
    model.timestamp = timestamp

    # If a model name already exists and we are loading a different model,
    # use the model file name as the model name.
    if model.name in MODELS and model.filename != MODELS[model.name].filename:
        _previous_name = model.name
        model.name = model.id

        # If the new model name is still in the model list (for instance,
        # if we put a cylinder.py in our plug-in directory), then append
        # an identifier.
        if model.name in MODELS and model.filename != MODELS[model.name].filename:
            model.name = model.id + '_user'
        logger.info("Model %s already exists: using %s [%s]",
                    _previous_name, model.name, model.filename)

    MODELS[model.name] = model
    MODEL_BY_PATH[path] = model
    return model
[4d76711]145
[87985ca]146
def make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Convert *model_info* into a SasView model wrapper.

    Returns a new subclass of SasviewModel named ``model_info.name``.  Its
    class attributes come from _generate_model_attributes(), plus an
    ``__init__`` forwarding *multiplicity* and the model ``filename``.
    """
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs = _generate_model_attributes(model_info)
    class_attrs['__init__'] = __init__
    class_attrs['filename'] = model_info.filename
    return type(model_info.name, (SasviewModel,), class_attrs)
159
160
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build the sasview model class for the model *name*.

    *name* is handed to generate.load_kernel_module(); the accepted forms
    (standard model name, and per the original notes possibly a custom model
    path) are whatever that function supports.

    Returns a class that can be used directly as a sasview model.
    """
    return make_model_from_info(
        modelinfo.make_model_info(generate.load_kernel_module(name)))
[72a081d]173
174
def _register_old_models():
    # type: () -> None
    """
    Place the new models into sasview under the old names.

    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
    is available to the plugin modules.  Then, for every entry in the 3.1.2
    conversion table, publish a pseudo-module ``sas.models.<OldName>`` (both
    as an attribute of sas.models and in sys.modules) whose ``<OldName>``
    attribute is the new model class.

    Models that are not registered (e.g. because they failed to load in
    load_standard_models) are skipped with a warning rather than aborting
    the whole registration.
    """
    import sys
    import sas   # needed in order to set sas.models
    import sas.sascalc.fit
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit

    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE
    for new_name, conversion in CONVERSION_TABLE.get((3, 1, 2), {}).items():
        # CoreShellEllipsoidModel => core_shell_ellipsoid:1
        new_name = new_name.split(':')[0]
        old_name = conversion[0] if len(conversion) < 3 else conversion[2]
        try:
            model = find_model(new_name)
        except ValueError:
            logger.warning("Cannot register old model name %s: model %s is not loaded",
                           old_name, new_name)
            continue
        ConstructedModule = type(old_name, (), {old_name: model})
        old_path = 'sas.models.' + old_name
        # Attach under the bare name so that sas.models.<OldName> resolves by
        # attribute access; the dotted path is only the sys.modules key.
        setattr(sas.models, old_name, ConstructedModule)
        sys.modules[old_path] = ConstructedModule
200
201
def MultiplicationModel(form_factor, structure_factor):
    # type: ("SasviewModel", "SasviewModel") -> "SasviewModel"
    """
    Return an instance of the product model of *form_factor* and
    *structure_factor*.

    The product model info is built by product.make_product_info from the
    two models' info; the new instance takes the multiplicity of
    *form_factor*.
    """
    product_info = product.make_product_info(
        form_factor._model_info, structure_factor._model_info)
    product_class = make_model_from_info(product_info)
    return product_class(form_factor.multiplicity)
[a80e64c]208
[ce27e21]209
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Generate the class attributes for the model.

    This should include all the information necessary to query the model
    details so that you do not need to instantiate a model to query it.

    All the attributes should be immutable to avoid accidents.

    :param model_info: the model description to summarize
    :return: mapping of class attribute name to value
    """
    # TODO: allow model to override axis labels input/output name/unit

    if model_info.profile is not None:
        xlabel = model_info.profile_axes[0]
    else:
        xlabel = ""
    non_fittable = []  # type: List[str]

    # Multiplicity: the kernel parameter named by model_info.control selects
    # the variant; it is never fitted.
    control = next((par for par in model_info.parameters.kernel_parameters
                    if par.name == model_info.control), None)
    if control is None:
        variants = MultiplicityInfo(0, "", [], xlabel)
    else:
        non_fittable.append(control.name)
        if control.choices:
            count = len(control.choices)
        else:
            count = int(control.limits[1])
        variants = MultiplicityInfo(count, control.name, control.choices, xlabel)

    # Only a single drop-down list parameter is available: the first kernel
    # parameter with choices supplies fun_list.
    selector = next((par for par in model_info.parameters.kernel_parameters
                     if par.choices), None)
    fun_list = [] if selector is None else selector.choices
    if selector is not None and selector.length > 1:
        non_fittable.extend(selector.id + str(k)
                            for k in range(1, selector.length + 1))

    # Group parameters by type.
    orientation_params, magnetic_params, fixed = [], [], []
    for par in model_info.parameters.user_parameters({}, is2d=True):
        if par.type == 'orientation':
            orientation_params += [par.name, par.name + ".width"]
            fixed.append(par.name + ".width")
        elif par.type == 'magnetic':
            orientation_params.append(par.name)
            magnetic_params.append(par.name)
            fixed.append(par.name + ".width")

    # Build class dictionary
    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }
[4d76711]278
[ce27e21]279class SasviewModel(object):
280    """
281    Sasview wrapper for opencl/ctypes model.
282    """
[fa5fd8d]283    # Model parameters for the specific model are set in the class constructor
284    # via the _generate_model_attributes function, which subclasses
285    # SasviewModel.  They are included here for typing and documentation
286    # purposes.
287    _model = None       # type: KernelModel
288    _model_info = None  # type: ModelInfo
289    #: load/save name for the model
290    id = None           # type: str
291    #: display name for the model
292    name = None         # type: str
293    #: short model description
294    description = None  # type: str
295    #: default model category
296    category = None     # type: str
297
298    #: names of the orientation parameters in the order they appear
[724257c]299    orientation_params = None # type: List[str]
[fa5fd8d]300    #: names of the magnetic parameters in the order they appear
[724257c]301    magnetic_params = None    # type: List[str]
[fa5fd8d]302    #: names of the fittable parameters
[724257c]303    fixed = None              # type: List[str]
[fa5fd8d]304    # TODO: the attribute fixed is ill-named
305
306    # Axis labels
307    input_name = "Q"
308    input_unit = "A^{-1}"
309    output_name = "Intensity"
310    output_unit = "cm^{-1}"
311
312    #: default cutoff for polydispersity
313    cutoff = 1e-5
314
315    # Note: Use non-mutable values for class attributes to avoid errors
316    #: parameters that are not fitted
317    non_fittable = ()        # type: Sequence[str]
318
319    #: True if model should appear as a structure factor
320    is_structure_factor = False
321    #: True if model should appear as a form factor
322    is_form_factor = False
323    #: True if model has multiplicity
324    is_multiplicity_model = False
[1f35235]325    #: Multiplicity information
[fa5fd8d]326    multiplicity_info = None # type: MultiplicityInfoType
327
328    # Per-instance variables
329    #: parameter {name: value} mapping
330    params = None      # type: Dict[str, float]
331    #: values for dispersion width, npts, nsigmas and type
332    dispersion = None  # type: Dict[str, Any]
333    #: units and limits for each parameter
[60f03de]334    details = None     # type: Dict[str, Sequence[Any]]
335    #                  # actual type is Dict[str, List[str, float, float]]
[04dc697]336    #: multiplicity value, or None if no multiplicity on the model
[fa5fd8d]337    multiplicity = None     # type: Optional[int]
[04dc697]338    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
339    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]
[fa5fd8d]340
341    def __init__(self, multiplicity=None):
[04dc697]342        # type: (Optional[int]) -> None
[2622b3f]343
[04045f4]344        # TODO: _persistency_dict to persistency_dict throughout sasview
345        # TODO: refactor multiplicity to encompass variants
346        # TODO: dispersion should be a class
[fa5fd8d]347        # TODO: refactor multiplicity info
348        # TODO: separate profile view from multiplicity
349        # The button label, x and y axis labels and scale need to be under
350        # the control of the model, not the fit page.  Maximum flexibility,
351        # the fit page would supply the canvas and the profile could plot
352        # how it wants, but this assumes matplotlib.  Next level is that
353        # we provide some sort of data description including title, labels
354        # and lines to plot.
355
[1f35235]356        # Get the list of hidden parameters given the multiplicity
[04045f4]357        # Don't include multiplicity in the list of parameters
[fa5fd8d]358        self.multiplicity = multiplicity
[04045f4]359        if multiplicity is not None:
360            hidden = self._model_info.get_hidden_parameters(multiplicity)
361            hidden |= set([self.multiplicity_info.control])
362        else:
363            hidden = set()
[8f93522]364        if self._model_info.structure_factor:
365            hidden.add('scale')
366            hidden.add('background')
367            self._model_info.parameters.defaults['background'] = 0.
[04045f4]368
[04dc697]369        self._persistency_dict = {}
[fa5fd8d]370        self.params = collections.OrderedDict()
[b3a85cd]371        self.dispersion = collections.OrderedDict()
[fa5fd8d]372        self.details = {}
[8977226]373        for p in self._model_info.parameters.user_parameters({}, is2d=True):
[04045f4]374            if p.name in hidden:
[fa5fd8d]375                continue
[fcd7bbd]376            self.params[p.name] = p.default
[fa5fd8d]377            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
[fb5914f]378            if p.polydisperse:
[fa5fd8d]379                self.details[p.id+".width"] = [
380                    "", 0.0, 1.0 if p.relative_pd else np.inf
381                ]
[fb5914f]382                self.dispersion[p.name] = {
383                    'width': 0,
384                    'npts': 35,
385                    'nsigmas': 3,
386                    'type': 'gaussian',
387                }
[ce27e21]388
[de97440]389    def __get_state__(self):
[fa5fd8d]390        # type: () -> Dict[str, Any]
[de97440]391        state = self.__dict__.copy()
[4d76711]392        state.pop('_model')
[de97440]393        # May need to reload model info on set state since it has pointers
394        # to python implementations of Iq, etc.
395        #state.pop('_model_info')
396        return state
397
398    def __set_state__(self, state):
[fa5fd8d]399        # type: (Dict[str, Any]) -> None
[de97440]400        self.__dict__ = state
[fb5914f]401        self._model = None
[de97440]402
[ce27e21]403    def __str__(self):
[fa5fd8d]404        # type: () -> str
[ce27e21]405        """
406        :return: string representation
407        """
408        return self.name
409
410    def is_fittable(self, par_name):
[fa5fd8d]411        # type: (str) -> bool
[ce27e21]412        """
413        Check if a given parameter is fittable or not
414
415        :param par_name: the parameter name to check
416        """
[e758662]417        return par_name in self.fixed
[ce27e21]418        #For the future
419        #return self.params[str(par_name)].is_fittable()
420
421
422    def getProfile(self):
[fa5fd8d]423        # type: () -> (np.ndarray, np.ndarray)
[ce27e21]424        """
425        Get SLD profile
426
427        : return: (z, beta) where z is a list of depth of the transition points
428                beta is a list of the corresponding SLD values
429        """
[745b7bb]430        args = {} # type: Dict[str, Any]
[fa5fd8d]431        for p in self._model_info.parameters.kernel_parameters:
432            if p.id == self.multiplicity_info.control:
[745b7bb]433                value = float(self.multiplicity)
[fa5fd8d]434            elif p.length == 1:
[745b7bb]435                value = self.params.get(p.id, np.NaN)
[fa5fd8d]436            else:
[745b7bb]437                value = np.array([self.params.get(p.id+str(k), np.NaN)
[b32dafd]438                                  for k in range(1, p.length+1)])
[745b7bb]439            args[p.id] = value
440
[e7fe459]441        x, y = self._model_info.profile(**args)
442        return x, 1e-6*y
[ce27e21]443
444    def setParam(self, name, value):
[fa5fd8d]445        # type: (str, float) -> None
[ce27e21]446        """
447        Set the value of a model parameter
448
449        :param name: name of the parameter
450        :param value: value of the parameter
451
452        """
453        # Look for dispersion parameters
454        toks = name.split('.')
[de0c4ba]455        if len(toks) == 2:
[ce27e21]456            for item in self.dispersion.keys():
[e758662]457                if item == toks[0]:
[ce27e21]458                    for par in self.dispersion[item]:
[e758662]459                        if par == toks[1]:
[ce27e21]460                            self.dispersion[item][par] = value
461                            return
462        else:
463            # Look for standard parameter
464            for item in self.params.keys():
[e758662]465                if item == name:
[ce27e21]466                    self.params[item] = value
467                    return
468
[63b32bb]469        raise ValueError("Model does not contain parameter %s" % name)
[ce27e21]470
471    def getParam(self, name):
[fa5fd8d]472        # type: (str) -> float
[ce27e21]473        """
474        Set the value of a model parameter
475
476        :param name: name of the parameter
477
478        """
479        # Look for dispersion parameters
480        toks = name.split('.')
[de0c4ba]481        if len(toks) == 2:
[ce27e21]482            for item in self.dispersion.keys():
[e758662]483                if item == toks[0]:
[ce27e21]484                    for par in self.dispersion[item]:
[e758662]485                        if par == toks[1]:
[ce27e21]486                            return self.dispersion[item][par]
487        else:
488            # Look for standard parameter
489            for item in self.params.keys():
[e758662]490                if item == name:
[ce27e21]491                    return self.params[item]
492
[63b32bb]493        raise ValueError("Model does not contain parameter %s" % name)
[ce27e21]494
495    def getParamList(self):
[04dc697]496        # type: () -> Sequence[str]
[ce27e21]497        """
498        Return a list of all available parameters for the model
499        """
[04dc697]500        param_list = list(self.params.keys())
[ce27e21]501        # WARNING: Extending the list with the dispersion parameters
[de0c4ba]502        param_list.extend(self.getDispParamList())
503        return param_list
[ce27e21]504
505    def getDispParamList(self):
[04dc697]506        # type: () -> Sequence[str]
[ce27e21]507        """
[fb5914f]508        Return a list of polydispersity parameters for the model
[ce27e21]509        """
[1780d59]510        # TODO: fix test so that parameter order doesn't matter
[3bcb88c]511        ret = ['%s.%s' % (p_name, ext)
512               for p_name in self.dispersion.keys()
513               for ext in ('npts', 'nsigmas', 'width')]
[9404dd3]514        #print(ret)
[1780d59]515        return ret
[ce27e21]516
517    def clone(self):
[04dc697]518        # type: () -> "SasviewModel"
[ce27e21]519        """ Return a identical copy of self """
520        return deepcopy(self)
521
522    def run(self, x=0.0):
[fa5fd8d]523        # type: (Union[float, (float, float), List[float]]) -> float
[ce27e21]524        """
525        Evaluate the model
526
527        :param x: input q, or [q,phi]
528
529        :return: scattering function P(q)
530
531        **DEPRECATED**: use calculate_Iq instead
532        """
[de0c4ba]533        if isinstance(x, (list, tuple)):
[3c56da87]534            # pylint: disable=unpacking-non-sequence
[ce27e21]535            q, phi = x
[60f03de]536            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
[ce27e21]537        else:
[60f03de]538            return self.calculate_Iq([x])[0]
[ce27e21]539
540
541    def runXY(self, x=0.0):
[fa5fd8d]542        # type: (Union[float, (float, float), List[float]]) -> float
[ce27e21]543        """
544        Evaluate the model in cartesian coordinates
545
546        :param x: input q, or [qx, qy]
547
548        :return: scattering function P(q)
549
550        **DEPRECATED**: use calculate_Iq instead
551        """
[de0c4ba]552        if isinstance(x, (list, tuple)):
[60f03de]553            return self.calculate_Iq([x[0]], [x[1]])[0]
[ce27e21]554        else:
[60f03de]555            return self.calculate_Iq([x])[0]
[ce27e21]556
557    def evalDistribution(self, qdist):
[04dc697]558        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
[d138d43]559        r"""
[ce27e21]560        Evaluate a distribution of q-values.
561
[d138d43]562        :param qdist: array of q or a list of arrays [qx,qy]
[ce27e21]563
[d138d43]564        * For 1D, a numpy array is expected as input
[ce27e21]565
[d138d43]566        ::
[ce27e21]567
[d138d43]568            evalDistribution(q)
[ce27e21]569
[d138d43]570          where *q* is a numpy array.
[ce27e21]571
[d138d43]572        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input
[ce27e21]573
[d138d43]574        ::
[ce27e21]575
[d138d43]576              qx = [ qx[0], qx[1], qx[2], ....]
577              qy = [ qy[0], qy[1], qy[2], ....]
[ce27e21]578
[d138d43]579        If the model is 1D only, then
[ce27e21]580
[d138d43]581        .. math::
[ce27e21]582
[d138d43]583            q = \sqrt{q_x^2+q_y^2}
[ce27e21]584
585        """
[de0c4ba]586        if isinstance(qdist, (list, tuple)):
[ce27e21]587            # Check whether we have a list of ndarrays [qx,qy]
588            qx, qy = qdist
[6d6508e]589            if not self._model_info.parameters.has_2d:
[de0c4ba]590                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
[5d4777d]591            else:
592                return self.calculate_Iq(qx, qy)
[ce27e21]593
594        elif isinstance(qdist, np.ndarray):
595            # We have a simple 1D distribution of q-values
596            return self.calculate_Iq(qdist)
597
598        else:
[3c56da87]599            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
600                            % type(qdist))
[ce27e21]601
[9dcb21d]602    def calc_composition_models(self, qx):
[64614ad]603        """
[9dcb21d]604        returns parts of the composition model or None if not a composition
605        model.
[64614ad]606        """
[946c8d27]607        # TODO: have calculate_Iq return the intermediates.
608        #
609        # The current interface causes calculate_Iq() to be called twice,
610        # once to get the combined result and again to get the intermediate
611        # results.  This is necessary for now.
612        # Long term, the solution is to change the interface to calculate_Iq
613        # so that it returns a results object containing all the bits:
[9644b5a]614        #     the A, B, C, ... of the composition model (and any subcomponents?)
[946c8d27]615        #     the P and S of the product model,
616        #     the combined model before resolution smearing,
617        #     the sasmodel before sesans conversion,
618        #     the oriented 2D model used to fit oriented usans data,
619        #     the final I(q),
620        #     ...
[9644b5a]621        #
[946c8d27]622        # Have the model calculator add all of these blindly to the data
623        # tree, and update the graphs which contain them.  The fitter
624        # needs to be updated to use the I(q) value only, ignoring the rest.
625        #
626        # The simple fix of returning the existing intermediate results
627        # will not work for a couple of reasons: (1) another thread may
628        # sneak in to compute its own results before calc_composition_models
629        # is called, and (2) calculate_Iq is currently called three times:
630        # once with q, once with q values before qmin and once with q values
631        # after q max.  Both of these should be addressed before
632        # replacing this code.
[9644b5a]633        composition = self._model_info.composition
634        if composition and composition[0] == 'product': # only P*S for now
635            with calculation_lock:
636                self._calculate_Iq(qx)
637                return self._intermediate_results
638        else:
639            return None
[bf8c271]640
[fa5fd8d]641    def calculate_Iq(self, qx, qy=None):
642        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
[ff7119b]643        """
644        Calculate Iq for one set of q with the current parameters.
645
646        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.
647
648        This should NOT be used for fitting since it copies the *q* vectors
649        to the card for each evaluation.
650        """
[a38b065]651        ## uncomment the following when trying to debug the uncoordinated calls
652        ## to calculate_Iq
653        #if calculation_lock.locked():
[724257c]654        #    logger.info("calculation waiting for another thread to complete")
655        #    logger.info("\n".join(traceback.format_stack()))
[a38b065]656
657        with calculation_lock:
658            return self._calculate_Iq(qx, qy)
659
660    def _calculate_Iq(self, qx, qy=None):
[6a0d6aa]661        #core.HAVE_OPENCL = False
[fb5914f]662        if self._model is None:
[d2bb604]663            self._model = core.build_model(self._model_info)
[fa5fd8d]664        if qy is not None:
665            q_vectors = [np.asarray(qx), np.asarray(qy)]
666        else:
667            q_vectors = [np.asarray(qx)]
[a738209]668        calculator = self._model.make_kernel(q_vectors)
[6a0d6aa]669        parameters = self._model_info.parameters
670        pairs = [self._get_weights(p) for p in parameters.call_parameters]
[9c1a59c]671        #weights.plot_weights(self._model_info, pairs)
[bde38b5]672        call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
[4edec6f]673        #call_details.show()
674        #print("pairs", pairs)
675        #print("params", self.params)
676        #print("values", values)
677        #print("is_mag", is_magnetic)
[6a0d6aa]678        result = calculator(call_details, values, cutoff=self.cutoff,
[9eb3632]679                            magnetic=is_magnetic)
[bf8c271]680        self._intermediate_results = getattr(calculator, 'results', None)
[a738209]681        calculator.release()
[9f37726]682        self._model.release()
[ce27e21]683        return result
684
685    def calculate_ER(self):
[fa5fd8d]686        # type: () -> float
[ce27e21]687        """
688        Calculate the effective radius for P(q)*S(q)
689
690        :return: the value of the effective radius
691        """
[4bfd277]692        if self._model_info.ER is None:
[ce27e21]693            return 1.0
694        else:
[4bfd277]695            value, weight = self._dispersion_mesh()
696            fv = self._model_info.ER(*value)
[9404dd3]697            #print(values[0].shape, weights.shape, fv.shape)
[4bfd277]698            return np.sum(weight * fv) / np.sum(weight)
[ce27e21]699
700    def calculate_VR(self):
[fa5fd8d]701        # type: () -> float
[ce27e21]702        """
703        Calculate the volf ratio for P(q)*S(q)
704
705        :return: the value of the volf ratio
706        """
[4bfd277]707        if self._model_info.VR is None:
[ce27e21]708            return 1.0
709        else:
[4bfd277]710            value, weight = self._dispersion_mesh()
711            whole, part = self._model_info.VR(*value)
712            return np.sum(weight * part) / np.sum(weight * whole)
[ce27e21]713
714    def set_dispersion(self, parameter, dispersion):
[fa5fd8d]715        # type: (str, weights.Dispersion) -> Dict[str, Any]
[ce27e21]716        """
717        Set the dispersion object for a model parameter
718
719        :param parameter: name of the parameter [string]
720        :param dispersion: dispersion object of type Dispersion
721        """
[fa800e72]722        if parameter in self.params:
[1780d59]723            # TODO: Store the disperser object directly in the model.
[56b2687]724            # The current method of relying on the sasview GUI to
[fa800e72]725            # remember them is kind of funky.
[1780d59]726            # Note: can't seem to get disperser parameters from sasview
[9c1a59c]727            # (1) Could create a sasview model that has not yet been
[1780d59]728            # converted, assign the disperser to one of its polydisperse
729            # parameters, then retrieve the disperser parameters from the
[9c1a59c]730            # sasview model.
731            # (2) Could write a disperser parameter retriever in sasview.
732            # (3) Could modify sasview to use sasmodels.weights dispersers.
[1780d59]733            # For now, rely on the fact that the sasview only ever uses
734            # new dispersers in the set_dispersion call and create a new
735            # one instead of trying to assign parameters.
[ce27e21]736            self.dispersion[parameter] = dispersion.get_pars()
737        else:
738            raise ValueError("%r is not a dispersity or orientation parameter")
739
[aa4946b]740    def _dispersion_mesh(self):
[fa5fd8d]741        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
[ce27e21]742        """
743        Create a mesh grid of dispersion parameters and weights.
744
745        Returns [p1,p2,...],w where pj is a vector of values for parameter j
746        and w is a vector containing the products for weights for each
747        parameter set in the vector.
748        """
[4bfd277]749        pars = [self._get_weights(p)
750                for p in self._model_info.parameters.call_parameters
751                if p.type == 'volume']
[9eb3632]752        return dispersion_mesh(self._model_info, pars)
[ce27e21]753
754    def _get_weights(self, par):
[fa5fd8d]755        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
[de0c4ba]756        """
[fb5914f]757        Return dispersion weights for parameter
[de0c4ba]758        """
[fa5fd8d]759        if par.name not in self.params:
760            if par.name == self.multiplicity_info.control:
[4edec6f]761                return [self.multiplicity], [1.0]
[fa5fd8d]762            else:
[8f93522]763                # For hidden parameters use the default value.
764                value = self._model_info.parameters.defaults.get(par.name, np.NaN)
765                return [value], [1.0]
[fa5fd8d]766        elif par.polydisperse:
[fb5914f]767            dis = self.dispersion[par.name]
[9c1a59c]768            if dis['type'] == 'array':
769                value, weight = dis['values'], dis['weights']
770            else:
771                value, weight = weights.get_weights(
772                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
773                    self.params[par.name], par.limits, par.relative_pd)
[fb5914f]774            return value, weight / np.sum(weight)
775        else:
[4edec6f]776            return [self.params[par.name]], [1.0]
[ce27e21]777
def test_cylinder():
    # type: () -> float
    """
    Build the standard cylinder model and evaluate it at [0.1, 0.1].

    Returns whatever ``evalDistribution`` produces for that input.
    """
    model = _make_standard_model('cylinder')()
    return model.evalDistribution([0.1, 0.1])
[de97440]786
def test_structure_factor():
    # type: () -> None
    """
    Test that 2-D hardsphere model runs and doesn't produce NaN.

    :raises ValueError: if any value returned by the model is NaN
    """
    Model = _make_standard_model('hardsphere')
    model = Model()
    value = model.evalDistribution([0.1, 0.1])
    # np.any keeps the truth test well-defined even if evalDistribution
    # returns an array with more than one element.
    if np.any(np.isnan(value)):
        raise ValueError("hardsphere returns null")
797
def test_rpa():
    # type: () -> float
    """
    Check that the 2-D RPA model runs; returns its value at [0.1, 0.1].

    The model class is instantiated with the argument 3.
    """
    rpa_class = _make_standard_model('rpa')
    model = rpa_class(3)
    return model.evalDistribution([0.1, 0.1])
[04045f4]806
def test_empty_distribution():
    # type: () -> None
    """
    Check that a cylinder whose radius is -1 evaluates to NaN.

    The docstring's intent is that such a model has no polydispersity
    points; the test sets radius to -1.0 and background to 0 and asserts
    that the first returned value is NaN.
    """
    cylinder = _make_standard_model('cylinder')()
    for par_name, par_value in (('radius', -1.0), ('background', 0.)):
        cylinder.setParam(par_name, par_value)
    result = cylinder.evalDistribution(np.asarray([0.1]))
    assert np.isnan(result[0]), "empty distribution fails"
[4d76711]818
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models

    Any error raised while building a model is annotated with the model
    name (via annotate_exception) and then re-raised.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            # Narrowed from a bare except: KeyboardInterrupt/SystemExit now
            # propagate without being annotated as model-load failures.
            annotate_exception("when loading "+name)
            raise
831
def test_old_name():
    # type: () -> None
    """
    Load and evaluate CylinderModel through the old sas.models import path.

    Does nothing when SUPPORT_OLD_STYLE_PLUGINS is false or when the
    ``sas`` package cannot be imported.
    """
    if not SUPPORT_OLD_STYLE_PLUGINS:
        return
    # sasview not on the path: nothing to check.
    try:
        import sas
    except ImportError:
        return
    load_standard_models()
    from sas.models.CylinderModel import CylinderModel
    model = CylinderModel()
    model.evalDistribution([0.1, 0.1])
847
if __name__ == "__main__":
    # Smoke test: evaluate the cylinder model once and report the result.
    cylinder_value = test_cylinder()
    print("cylinder(0.1,0.1)=%g" % cylinder_value)
Note: See TracBrowser for help on using the repository browser.