source: sasmodels/sasmodels/sasview_model.py @ 17db833

core_shell_microgels magnetic_model ticket-1257-vesicle-product ticket_1156 ticket_1265_superball ticket_822_more_unit_tests
Last change on this file since 17db833 was 17db833, checked in by Paul Kienzle <pkienzle@…>, 6 years ago

fix structure factor models, so scale=1, background=0 again

  • Property mode set to 100644
File size: 30.7 KB
RevLine 
[87985ca]1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
[92d38285]7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
[87985ca]9"""
[4d76711]10from __future__ import print_function
[87985ca]11
[ce27e21]12import math
13from copy import deepcopy
[2622b3f]14import collections
[4d76711]15import traceback
16import logging
[724257c]17from os.path import basename, splitext, abspath, getmtime
[9f8ade1]18try:
19    import _thread as thread
20except ImportError:
21    import thread
[ce27e21]22
[7ae2b7f]23import numpy as np  # type: ignore
[ce27e21]24
[aa4946b]25from . import core
[4d76711]26from . import custom
[a80e64c]27from . import product
[72a081d]28from . import generate
[fb5914f]29from . import weights
[6d6508e]30from . import modelinfo
[bde38b5]31from .details import make_kernel_args, dispersion_mesh
[ff7119b]32
[fa5fd8d]33try:
[60f03de]34    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable
[fa5fd8d]35    from .modelinfo import ModelInfo, Parameter
36    from .kernel import KernelModel
37    MultiplicityInfoType = NamedTuple(
[a9bc435]38        'MultiplicityInfo',
[fa5fd8d]39        [("number", int), ("control", str), ("choices", List[str]),
40         ("x_axis_label", str)])
[60f03de]41    SasviewModelType = Callable[[int], "SasviewModel"]
[fa5fd8d]42except ImportError:
43    pass
44
#: Module-wide logger.
logger = logging.getLogger(__name__)

#: Lock taken around kernel evaluation by SasviewModel.calculate_Iq and
#: SasviewModel.calc_composition_models.
calculation_lock = thread.allocate_lock()

#: When True, load_standard_models() also calls _register_old_models() so
#: pre-existing plugins using the old names and parameters keep working.
SUPPORT_OLD_STYLE_PLUGINS = True

# TODO: separate x_axis_label from multiplicity info
#: Description of a model's control (multiplicity) parameter.
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo', "number control choices x_axis_label")

#: set of defined models (standard and custom), keyed by model name
MODELS = dict()  # type: Dict[str, SasviewModelType]
#: custom model {path: model} mapping so we can check timestamps
MODEL_BY_PATH = dict()  # type: Dict[str, SasviewModelType]
63
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Look up a model class by name.

    A name ending in ``.py`` is treated as a file path and loaded through
    load_custom_model().  Any other name must already be a key of MODELS.

    :raises ValueError: if *modelname* is neither a ``.py`` path nor a
        registered model name.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
78
[56b2687]79
[fa5fd8d]80# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Register every builtin model in MODELS and return all known models.

    Each name reported by core.list_models() is built with
    _make_standard_model().  A model that fails to build is skipped, and
    its traceback is logged at error level.  When SUPPORT_OLD_STYLE_PLUGINS
    is set, the old-style names are registered as well.

    The returned list holds every value in MODELS.  That includes custom
    models registered earlier by load_custom_model().
    """
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
        except Exception:
            logger.error(traceback.format_exc())
        else:
            MODELS[model_name] = model_class

    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()

    return list(MODELS.values())
[de97440]98
[4d76711]99
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load a custom model given the model path.

    Results are cached in MODEL_BY_PATH.  A cached model is reused as long
    as the file's modification time still matches the stored *timestamp*.

    If the module defines ``Model``, it is treated as an old-style plugin
    and that class is used directly.  Missing *name*, *filename* and *id*
    attributes are filled in on it.  Otherwise a SasView wrapper class is
    built from the module's model info.

    If a model from a different file is already registered under the same
    name, the new model is renamed to its *id*.  If that name is also taken,
    it becomes ``id + '_user'``.  The model is then stored in MODELS and
    MODEL_BY_PATH and returned.
    """
    model = MODEL_BY_PATH.get(path, None)
    if model is not None and model.timestamp == getmtime(path):
        #logger.info("Model already loaded %s", path)
        return model

    #logger.info("Loading model %s", path)
    kernel_module = custom.load_custom_kernel_module(path)
    if hasattr(kernel_module, 'Model'):
        model = kernel_module.Model
        # Old style models do not set the name in the class attributes, so
        # set it here; this name will be overridden when the object is created
        # with an instance variable that has the same value.
        if model.name == "":
            model.name = splitext(basename(path))[0]
        if not hasattr(model, 'filename'):
            # Point at the source file when the module came from compiled
            # bytecode.  Only the suffix is changed: a blanket
            # str.replace('.pyc', '.py') would also rewrite any directory
            # component containing '.pyc' (e.g. '.pycharm_helpers').
            filename = abspath(kernel_module.__file__)
            if filename.endswith('.pyc'):
                filename = filename[:-1]
            model.filename = filename
        if not hasattr(model, 'id'):
            model.id = splitext(basename(model.filename))[0]
    else:
        model_info = modelinfo.make_model_info(kernel_module)
        model = make_model_from_info(model_info)
    model.timestamp = getmtime(path)

    # If a model name already exists and we are loading a different model,
    # use the model file name as the model name.
    if model.name in MODELS and model.filename != MODELS[model.name].filename:
        _previous_name = model.name
        model.name = model.id

        # If the new model name is still in the model list (for instance,
        # if we put a cylinder.py in our plug-in directory), then append
        # an identifier.
        if model.name in MODELS and model.filename != MODELS[model.name].filename:
            model.name = model.id + '_user'
        logger.info("Model %s already exists: using %s [%s]",
                    _previous_name, model.name, model.filename)

    MODELS[model.name] = model
    MODEL_BY_PATH[path] = model
    return model
[4d76711]145
[87985ca]146
def make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Build a SasviewModel subclass named after *model_info*.

    The class attributes come from _generate_model_attributes(), plus the
    model's *filename*.  The generated ``__init__`` only forwards
    *multiplicity* to SasviewModel.__init__.
    """
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs = _generate_model_attributes(model_info)
    class_attrs.update(__init__=__init__, filename=model_info.filename)
    return type(model_info.name, (SasviewModel,), class_attrs)
159
160
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build the sasview model class for the model *name*.

    The kernel module is located with generate.load_kernel_module(),
    converted to model info, and wrapped with make_model_from_info().
    The returned class can be used directly as a sasview model.
    """
    return make_model_from_info(
        modelinfo.make_model_info(generate.load_kernel_module(name)))
[72a081d]173
174
def _register_old_models():
    # type: () -> None
    """
    Place the new models into sasview under the old names.

    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
    is available to the plugin modules.

    For every entry in the (3, 1, 2) section of CONVERSION_TABLE, a class
    named after the old model name is created.  Its only attribute is the
    new model class, found with find_model().  It is installed in
    sys.modules under 'sas.models.<old_name>'.
    """
    import sys
    import sas   # needed in order to set sas.models
    import sas.sascalc.fit

    # Alias the fit package both in the import system and as an attribute
    # of the sas package.
    fit_package = sas.sascalc.fit
    sys.modules['sas.models'] = fit_package
    sas.models = fit_package

    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE

    old_conversions = CONVERSION_TABLE.get((3, 1, 2), {})
    for table_key, conversion in old_conversions.items():
        # CoreShellEllipsoidModel => core_shell_ellipsoid:1
        base_name = table_key.split(':')[0]
        # A third conversion entry, when present, overrides the old name.
        if len(conversion) < 3:
            old_name = conversion[0]
        else:
            old_name = conversion[2]
        shim = type(old_name, (), {old_name: find_model(base_name)})
        old_path = 'sas.models.' + old_name
        # NOTE(review): the attribute set here is literally named
        # 'sas.models.<old_name>' (dots included), which normal attribute
        # access cannot reach; setattr(sas.models, old_name, ...) may have
        # been intended.  The sys.modules entry below is what serves
        # "import sas.models.<old_name>".
        setattr(sas.models, old_path, shim)
        sys.modules[old_path] = shim
200
201
def MultiplicationModel(form_factor, structure_factor):
    # type: ("SasviewModel", "SasviewModel") -> "SasviewModel"
    """
    Return an instance of the product model form_factor * structure_factor.

    The product model info is built by product.make_product_info() from
    the two models' model info.  It is wrapped as a sasview model class,
    which is instantiated with the form factor's multiplicity.
    """
    product_info = product.make_product_info(
        form_factor._model_info, structure_factor._model_info)
    product_class = make_model_from_info(product_info)
    return product_class(form_factor.multiplicity)
[a80e64c]208
[ce27e21]209
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Generate the class attributes for the model.

    This should include all the information necessary to query the model
    details so that you do not need to instantiate a model to query it.

    All the attributes should be immutable to avoid accidents.

    The returned dict holds:
      - the model info itself and its name, id, description and category;
      - the structure/form factor flags;
      - the multiplicity info of the control parameter, if there is one;
      - tuples of orientation, magnetic, fixed and non-fittable parameter
        names;
      - the choices of the first drop-down parameter, as *fun_list*.
    """
    # TODO: allow model to override axis labels input/output name/unit

    # Multiplicity: the kernel parameter named by model_info.control is not
    # fittable; its choices (or its upper limit) give the multiplicity.
    non_fittable = []  # type: List[str]
    xlabel = "" if model_info.profile is None else model_info.profile_axes[0]
    variants = MultiplicityInfo(0, "", [], xlabel)
    control_par = next((p for p in model_info.parameters.kernel_parameters
                        if p.name == model_info.control), None)
    if control_par is not None:
        non_fittable.append(control_par.name)
        if control_par.choices:
            count = len(control_par.choices)
        else:
            count = int(control_par.limits[1])
        variants = MultiplicityInfo(count, control_par.name,
                                    control_par.choices, xlabel)

    # Only a single drop-down list parameter available: the first kernel
    # parameter with choices supplies fun_list.
    fun_list = []
    choice_par = next((p for p in model_info.parameters.kernel_parameters
                       if p.choices), None)
    if choice_par is not None:
        fun_list = choice_par.choices
        if choice_par.length > 1:
            non_fittable.extend(choice_par.id + str(k)
                                for k in range(1, choice_par.length + 1))

    # Orientation and magnetic parameter sets.  Magnetic parameters are also
    # listed with the orientation parameters.
    orientation_params = []
    magnetic_params = []
    fixed = []
    for p in model_info.parameters.user_parameters({}, is2d=True):
        if p.type == 'orientation':
            orientation_params.extend((p.name, p.name + ".width"))
            fixed.append(p.name + ".width")
        elif p.type == 'magnetic':
            orientation_params.append(p.name)
            magnetic_params.append(p.name)
            fixed.append(p.name + ".width")

    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }
[4d76711]278
[ce27e21]279class SasviewModel(object):
280    """
281    Sasview wrapper for opencl/ctypes model.
282    """
[fa5fd8d]283    # Model parameters for the specific model are set in the class constructor
284    # via the _generate_model_attributes function, which subclasses
285    # SasviewModel.  They are included here for typing and documentation
286    # purposes.
287    _model = None       # type: KernelModel
288    _model_info = None  # type: ModelInfo
289    #: load/save name for the model
290    id = None           # type: str
291    #: display name for the model
292    name = None         # type: str
293    #: short model description
294    description = None  # type: str
295    #: default model category
296    category = None     # type: str
297
298    #: names of the orientation parameters in the order they appear
[724257c]299    orientation_params = None # type: List[str]
[fa5fd8d]300    #: names of the magnetic parameters in the order they appear
[724257c]301    magnetic_params = None    # type: List[str]
[fa5fd8d]302    #: names of the fittable parameters
[724257c]303    fixed = None              # type: List[str]
[fa5fd8d]304    # TODO: the attribute fixed is ill-named
305
306    # Axis labels
307    input_name = "Q"
308    input_unit = "A^{-1}"
309    output_name = "Intensity"
310    output_unit = "cm^{-1}"
311
312    #: default cutoff for polydispersity
313    cutoff = 1e-5
314
315    # Note: Use non-mutable values for class attributes to avoid errors
316    #: parameters that are not fitted
317    non_fittable = ()        # type: Sequence[str]
318
319    #: True if model should appear as a structure factor
320    is_structure_factor = False
321    #: True if model should appear as a form factor
322    is_form_factor = False
323    #: True if model has multiplicity
324    is_multiplicity_model = False
[1f35235]325    #: Multiplicity information
[fa5fd8d]326    multiplicity_info = None # type: MultiplicityInfoType
327
328    # Per-instance variables
329    #: parameter {name: value} mapping
330    params = None      # type: Dict[str, float]
331    #: values for dispersion width, npts, nsigmas and type
332    dispersion = None  # type: Dict[str, Any]
333    #: units and limits for each parameter
[60f03de]334    details = None     # type: Dict[str, Sequence[Any]]
335    #                  # actual type is Dict[str, List[str, float, float]]
[04dc697]336    #: multiplicity value, or None if no multiplicity on the model
[fa5fd8d]337    multiplicity = None     # type: Optional[int]
[04dc697]338    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
339    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]
[fa5fd8d]340
341    def __init__(self, multiplicity=None):
[04dc697]342        # type: (Optional[int]) -> None
[2622b3f]343
[04045f4]344        # TODO: _persistency_dict to persistency_dict throughout sasview
345        # TODO: refactor multiplicity to encompass variants
346        # TODO: dispersion should be a class
[fa5fd8d]347        # TODO: refactor multiplicity info
348        # TODO: separate profile view from multiplicity
349        # The button label, x and y axis labels and scale need to be under
350        # the control of the model, not the fit page.  Maximum flexibility,
351        # the fit page would supply the canvas and the profile could plot
352        # how it wants, but this assumes matplotlib.  Next level is that
353        # we provide some sort of data description including title, labels
354        # and lines to plot.
355
[1f35235]356        # Get the list of hidden parameters given the multiplicity
[04045f4]357        # Don't include multiplicity in the list of parameters
[fa5fd8d]358        self.multiplicity = multiplicity
[04045f4]359        if multiplicity is not None:
360            hidden = self._model_info.get_hidden_parameters(multiplicity)
361            hidden |= set([self.multiplicity_info.control])
362        else:
363            hidden = set()
[8f93522]364        if self._model_info.structure_factor:
365            hidden.add('scale')
366            hidden.add('background')
367            self._model_info.parameters.defaults['background'] = 0.
[04045f4]368
[04dc697]369        self._persistency_dict = {}
[fa5fd8d]370        self.params = collections.OrderedDict()
[b3a85cd]371        self.dispersion = collections.OrderedDict()
[fa5fd8d]372        self.details = {}
[8977226]373        for p in self._model_info.parameters.user_parameters({}, is2d=True):
[04045f4]374            if p.name in hidden:
[fa5fd8d]375                continue
[fcd7bbd]376            self.params[p.name] = p.default
[fa5fd8d]377            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
[fb5914f]378            if p.polydisperse:
[fa5fd8d]379                self.details[p.id+".width"] = [
380                    "", 0.0, 1.0 if p.relative_pd else np.inf
381                ]
[fb5914f]382                self.dispersion[p.name] = {
383                    'width': 0,
384                    'npts': 35,
385                    'nsigmas': 3,
386                    'type': 'gaussian',
387                }
[ce27e21]388
[de97440]389    def __get_state__(self):
[fa5fd8d]390        # type: () -> Dict[str, Any]
[de97440]391        state = self.__dict__.copy()
[4d76711]392        state.pop('_model')
[de97440]393        # May need to reload model info on set state since it has pointers
394        # to python implementations of Iq, etc.
395        #state.pop('_model_info')
396        return state
397
398    def __set_state__(self, state):
[fa5fd8d]399        # type: (Dict[str, Any]) -> None
[de97440]400        self.__dict__ = state
[fb5914f]401        self._model = None
[de97440]402
[ce27e21]403    def __str__(self):
[fa5fd8d]404        # type: () -> str
[ce27e21]405        """
406        :return: string representation
407        """
408        return self.name
409
410    def is_fittable(self, par_name):
[fa5fd8d]411        # type: (str) -> bool
[ce27e21]412        """
413        Check if a given parameter is fittable or not
414
415        :param par_name: the parameter name to check
416        """
[e758662]417        return par_name in self.fixed
[ce27e21]418        #For the future
419        #return self.params[str(par_name)].is_fittable()
420
421
422    def getProfile(self):
[fa5fd8d]423        # type: () -> (np.ndarray, np.ndarray)
[ce27e21]424        """
425        Get SLD profile
426
427        : return: (z, beta) where z is a list of depth of the transition points
428                beta is a list of the corresponding SLD values
429        """
[745b7bb]430        args = {} # type: Dict[str, Any]
[fa5fd8d]431        for p in self._model_info.parameters.kernel_parameters:
432            if p.id == self.multiplicity_info.control:
[745b7bb]433                value = float(self.multiplicity)
[fa5fd8d]434            elif p.length == 1:
[745b7bb]435                value = self.params.get(p.id, np.NaN)
[fa5fd8d]436            else:
[745b7bb]437                value = np.array([self.params.get(p.id+str(k), np.NaN)
[b32dafd]438                                  for k in range(1, p.length+1)])
[745b7bb]439            args[p.id] = value
440
[e7fe459]441        x, y = self._model_info.profile(**args)
442        return x, 1e-6*y
[ce27e21]443
444    def setParam(self, name, value):
[fa5fd8d]445        # type: (str, float) -> None
[ce27e21]446        """
447        Set the value of a model parameter
448
449        :param name: name of the parameter
450        :param value: value of the parameter
451
452        """
453        # Look for dispersion parameters
454        toks = name.split('.')
[de0c4ba]455        if len(toks) == 2:
[ce27e21]456            for item in self.dispersion.keys():
[e758662]457                if item == toks[0]:
[ce27e21]458                    for par in self.dispersion[item]:
[e758662]459                        if par == toks[1]:
[ce27e21]460                            self.dispersion[item][par] = value
461                            return
462        else:
463            # Look for standard parameter
464            for item in self.params.keys():
[e758662]465                if item == name:
[ce27e21]466                    self.params[item] = value
467                    return
468
[63b32bb]469        raise ValueError("Model does not contain parameter %s" % name)
[ce27e21]470
471    def getParam(self, name):
[fa5fd8d]472        # type: (str) -> float
[ce27e21]473        """
474        Set the value of a model parameter
475
476        :param name: name of the parameter
477
478        """
479        # Look for dispersion parameters
480        toks = name.split('.')
[de0c4ba]481        if len(toks) == 2:
[ce27e21]482            for item in self.dispersion.keys():
[e758662]483                if item == toks[0]:
[ce27e21]484                    for par in self.dispersion[item]:
[e758662]485                        if par == toks[1]:
[ce27e21]486                            return self.dispersion[item][par]
487        else:
488            # Look for standard parameter
489            for item in self.params.keys():
[e758662]490                if item == name:
[ce27e21]491                    return self.params[item]
492
[63b32bb]493        raise ValueError("Model does not contain parameter %s" % name)
[ce27e21]494
495    def getParamList(self):
[04dc697]496        # type: () -> Sequence[str]
[ce27e21]497        """
498        Return a list of all available parameters for the model
499        """
[04dc697]500        param_list = list(self.params.keys())
[ce27e21]501        # WARNING: Extending the list with the dispersion parameters
[de0c4ba]502        param_list.extend(self.getDispParamList())
503        return param_list
[ce27e21]504
505    def getDispParamList(self):
[04dc697]506        # type: () -> Sequence[str]
[ce27e21]507        """
[fb5914f]508        Return a list of polydispersity parameters for the model
[ce27e21]509        """
[1780d59]510        # TODO: fix test so that parameter order doesn't matter
[3bcb88c]511        ret = ['%s.%s' % (p_name, ext)
512               for p_name in self.dispersion.keys()
513               for ext in ('npts', 'nsigmas', 'width')]
[9404dd3]514        #print(ret)
[1780d59]515        return ret
[ce27e21]516
517    def clone(self):
[04dc697]518        # type: () -> "SasviewModel"
[ce27e21]519        """ Return a identical copy of self """
520        return deepcopy(self)
521
522    def run(self, x=0.0):
[fa5fd8d]523        # type: (Union[float, (float, float), List[float]]) -> float
[ce27e21]524        """
525        Evaluate the model
526
527        :param x: input q, or [q,phi]
528
529        :return: scattering function P(q)
530
531        **DEPRECATED**: use calculate_Iq instead
532        """
[de0c4ba]533        if isinstance(x, (list, tuple)):
[3c56da87]534            # pylint: disable=unpacking-non-sequence
[ce27e21]535            q, phi = x
[60f03de]536            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
[ce27e21]537        else:
[60f03de]538            return self.calculate_Iq([x])[0]
[ce27e21]539
540
541    def runXY(self, x=0.0):
[fa5fd8d]542        # type: (Union[float, (float, float), List[float]]) -> float
[ce27e21]543        """
544        Evaluate the model in cartesian coordinates
545
546        :param x: input q, or [qx, qy]
547
548        :return: scattering function P(q)
549
550        **DEPRECATED**: use calculate_Iq instead
551        """
[de0c4ba]552        if isinstance(x, (list, tuple)):
[60f03de]553            return self.calculate_Iq([x[0]], [x[1]])[0]
[ce27e21]554        else:
[60f03de]555            return self.calculate_Iq([x])[0]
[ce27e21]556
557    def evalDistribution(self, qdist):
[04dc697]558        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
[d138d43]559        r"""
[ce27e21]560        Evaluate a distribution of q-values.
561
[d138d43]562        :param qdist: array of q or a list of arrays [qx,qy]
[ce27e21]563
[d138d43]564        * For 1D, a numpy array is expected as input
[ce27e21]565
[d138d43]566        ::
[ce27e21]567
[d138d43]568            evalDistribution(q)
[ce27e21]569
[d138d43]570          where *q* is a numpy array.
[ce27e21]571
[d138d43]572        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input
[ce27e21]573
[d138d43]574        ::
[ce27e21]575
[d138d43]576              qx = [ qx[0], qx[1], qx[2], ....]
577              qy = [ qy[0], qy[1], qy[2], ....]
[ce27e21]578
[d138d43]579        If the model is 1D only, then
[ce27e21]580
[d138d43]581        .. math::
[ce27e21]582
[d138d43]583            q = \sqrt{q_x^2+q_y^2}
[ce27e21]584
585        """
[de0c4ba]586        if isinstance(qdist, (list, tuple)):
[ce27e21]587            # Check whether we have a list of ndarrays [qx,qy]
588            qx, qy = qdist
[6d6508e]589            if not self._model_info.parameters.has_2d:
[de0c4ba]590                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
[5d4777d]591            else:
592                return self.calculate_Iq(qx, qy)
[ce27e21]593
594        elif isinstance(qdist, np.ndarray):
595            # We have a simple 1D distribution of q-values
596            return self.calculate_Iq(qdist)
597
598        else:
[3c56da87]599            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
600                            % type(qdist))
[ce27e21]601
[9dcb21d]602    def calc_composition_models(self, qx):
[64614ad]603        """
[9dcb21d]604        returns parts of the composition model or None if not a composition
605        model.
[64614ad]606        """
[946c8d27]607        # TODO: have calculate_Iq return the intermediates.
608        #
609        # The current interface causes calculate_Iq() to be called twice,
610        # once to get the combined result and again to get the intermediate
611        # results.  This is necessary for now.
612        # Long term, the solution is to change the interface to calculate_Iq
613        # so that it returns a results object containing all the bits:
[9644b5a]614        #     the A, B, C, ... of the composition model (and any subcomponents?)
[946c8d27]615        #     the P and S of the product model,
616        #     the combined model before resolution smearing,
617        #     the sasmodel before sesans conversion,
618        #     the oriented 2D model used to fit oriented usans data,
619        #     the final I(q),
620        #     ...
[9644b5a]621        #
[946c8d27]622        # Have the model calculator add all of these blindly to the data
623        # tree, and update the graphs which contain them.  The fitter
624        # needs to be updated to use the I(q) value only, ignoring the rest.
625        #
626        # The simple fix of returning the existing intermediate results
627        # will not work for a couple of reasons: (1) another thread may
628        # sneak in to compute its own results before calc_composition_models
629        # is called, and (2) calculate_Iq is currently called three times:
630        # once with q, once with q values before qmin and once with q values
631        # after q max.  Both of these should be addressed before
632        # replacing this code.
[9644b5a]633        composition = self._model_info.composition
634        if composition and composition[0] == 'product': # only P*S for now
635            with calculation_lock:
636                self._calculate_Iq(qx)
637                return self._intermediate_results
638        else:
639            return None
[bf8c271]640
[fa5fd8d]641    def calculate_Iq(self, qx, qy=None):
642        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
[ff7119b]643        """
644        Calculate Iq for one set of q with the current parameters.
645
646        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.
647
648        This should NOT be used for fitting since it copies the *q* vectors
649        to the card for each evaluation.
650        """
[a38b065]651        ## uncomment the following when trying to debug the uncoordinated calls
652        ## to calculate_Iq
653        #if calculation_lock.locked():
[724257c]654        #    logger.info("calculation waiting for another thread to complete")
655        #    logger.info("\n".join(traceback.format_stack()))
[a38b065]656
657        with calculation_lock:
658            return self._calculate_Iq(qx, qy)
659
660    def _calculate_Iq(self, qx, qy=None):
[6a0d6aa]661        #core.HAVE_OPENCL = False
[fb5914f]662        if self._model is None:
[d2bb604]663            self._model = core.build_model(self._model_info)
[fa5fd8d]664        if qy is not None:
665            q_vectors = [np.asarray(qx), np.asarray(qy)]
666        else:
667            q_vectors = [np.asarray(qx)]
[a738209]668        calculator = self._model.make_kernel(q_vectors)
[6a0d6aa]669        parameters = self._model_info.parameters
670        pairs = [self._get_weights(p) for p in parameters.call_parameters]
[9c1a59c]671        #weights.plot_weights(self._model_info, pairs)
[bde38b5]672        call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
[4edec6f]673        #call_details.show()
674        #print("pairs", pairs)
[ce99754]675        #for k, p in enumerate(self._model_info.parameters.call_parameters):
676        #    print(k, p.name, *pairs[k])
[4edec6f]677        #print("params", self.params)
678        #print("values", values)
679        #print("is_mag", is_magnetic)
[6a0d6aa]680        result = calculator(call_details, values, cutoff=self.cutoff,
[9eb3632]681                            magnetic=is_magnetic)
[ce99754]682        #print("result", result)
[bf8c271]683        self._intermediate_results = getattr(calculator, 'results', None)
[a738209]684        calculator.release()
[9f37726]685        self._model.release()
[ce27e21]686        return result
687
688    def calculate_ER(self):
[fa5fd8d]689        # type: () -> float
[ce27e21]690        """
691        Calculate the effective radius for P(q)*S(q)
692
693        :return: the value of the effective radius
694        """
[4bfd277]695        if self._model_info.ER is None:
[ce27e21]696            return 1.0
697        else:
[4bfd277]698            value, weight = self._dispersion_mesh()
699            fv = self._model_info.ER(*value)
[9404dd3]700            #print(values[0].shape, weights.shape, fv.shape)
[4bfd277]701            return np.sum(weight * fv) / np.sum(weight)
[ce27e21]702
703    def calculate_VR(self):
[fa5fd8d]704        # type: () -> float
[ce27e21]705        """
706        Calculate the volf ratio for P(q)*S(q)
707
708        :return: the value of the volf ratio
709        """
[4bfd277]710        if self._model_info.VR is None:
[ce27e21]711            return 1.0
712        else:
[4bfd277]713            value, weight = self._dispersion_mesh()
714            whole, part = self._model_info.VR(*value)
715            return np.sum(weight * part) / np.sum(weight * whole)
[ce27e21]716
717    def set_dispersion(self, parameter, dispersion):
[fa5fd8d]718        # type: (str, weights.Dispersion) -> Dict[str, Any]
[ce27e21]719        """
720        Set the dispersion object for a model parameter
721
722        :param parameter: name of the parameter [string]
723        :param dispersion: dispersion object of type Dispersion
724        """
[fa800e72]725        if parameter in self.params:
[1780d59]726            # TODO: Store the disperser object directly in the model.
[56b2687]727            # The current method of relying on the sasview GUI to
[fa800e72]728            # remember them is kind of funky.
[1780d59]729            # Note: can't seem to get disperser parameters from sasview
[9c1a59c]730            # (1) Could create a sasview model that has not yet been
[1780d59]731            # converted, assign the disperser to one of its polydisperse
732            # parameters, then retrieve the disperser parameters from the
[9c1a59c]733            # sasview model.
734            # (2) Could write a disperser parameter retriever in sasview.
735            # (3) Could modify sasview to use sasmodels.weights dispersers.
[1780d59]736            # For now, rely on the fact that the sasview only ever uses
737            # new dispersers in the set_dispersion call and create a new
738            # one instead of trying to assign parameters.
[ce27e21]739            self.dispersion[parameter] = dispersion.get_pars()
740        else:
741            raise ValueError("%r is not a dispersity or orientation parameter")
742
[aa4946b]743    def _dispersion_mesh(self):
[fa5fd8d]744        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
[ce27e21]745        """
746        Create a mesh grid of dispersion parameters and weights.
747
748        Returns [p1,p2,...],w where pj is a vector of values for parameter j
749        and w is a vector containing the products for weights for each
750        parameter set in the vector.
751        """
[4bfd277]752        pars = [self._get_weights(p)
753                for p in self._model_info.parameters.call_parameters
754                if p.type == 'volume']
[9eb3632]755        return dispersion_mesh(self._model_info, pars)
[ce27e21]756
757    def _get_weights(self, par):
[fa5fd8d]758        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
[de0c4ba]759        """
[fb5914f]760        Return dispersion weights for parameter
[de0c4ba]761        """
[fa5fd8d]762        if par.name not in self.params:
763            if par.name == self.multiplicity_info.control:
[32f87a5]764                return self.multiplicity, [self.multiplicity], [1.0]
[fa5fd8d]765            else:
[17db833]766                # For hidden parameters use default values.  This sets
767                # scale=1 and background=0 for structure factors
768                default = self._model_info.parameters.defaults.get(par.name, np.NaN)
769                return default, [default], [1.0]
[fa5fd8d]770        elif par.polydisperse:
[32f87a5]771            value = self.params[par.name]
[fb5914f]772            dis = self.dispersion[par.name]
[9c1a59c]773            if dis['type'] == 'array':
[32f87a5]774                dispersity, weight = dis['values'], dis['weights']
[9c1a59c]775            else:
[32f87a5]776                dispersity, weight = weights.get_weights(
[9c1a59c]777                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
[32f87a5]778                    value, par.limits, par.relative_pd)
779            return value, dispersity, weight
[fb5914f]780        else:
[32f87a5]781            value = self.params[par.name]
[ce99754]782            return value, [value], [1.0]
[ce27e21]783
def test_cylinder():
    # type: () -> float
    """
    Build the standard cylinder model with its default parameters and
    evaluate it at the 2-D point (qx, qy) = (0.1, 0.1).

    :return: the model value at that point.
    """
    model_class = _make_standard_model('cylinder')
    return model_class().evalDistribution([0.1, 0.1])
[de97440]792
def test_structure_factor():
    # type: () -> None
    """
    Test that the hardsphere model runs in 1-D and 2-D without producing NaN.

    :raises ValueError: if either evaluation returns NaN.
    """
    Model = _make_standard_model('hardsphere')
    model = Model()
    value2d = model.evalDistribution([0.1, 0.1])
    # |(0.1, 0.1)| = 0.1*sqrt(2), so the 1-D point has the same |q|.
    value1d = model.evalDistribution(np.array([0.1*np.sqrt(2)]))
    # np.any keeps the check well-defined whatever the shape of the result.
    if np.any(np.isnan(value1d)) or np.any(np.isnan(value2d)):
        raise ValueError("hardsphere returns nan")
[8f93522]805
def test_product():
    # type: () -> None
    """
    Test that the 2-D cylinder*hayter_msa product model runs and doesn't
    produce NaN.

    :raises ValueError: if the product model evaluates to NaN.
    """
    S = _make_standard_model('hayter_msa')()
    P = _make_standard_model('cylinder')()
    model = MultiplicationModel(P, S)
    value = model.evalDistribution([0.1, 0.1])
    if np.any(np.isnan(value)):
        raise ValueError("cylinder*hayter_msa returns nan")
817
def test_rpa():
    # type: () -> float
    """
    Check that the 2-D RPA model runs.

    The model is constructed with ``3`` as its constructor argument and
    evaluated at (qx, qy) = (0.1, 0.1).

    :return: the computed value.
    """
    rpa_model = _make_standard_model('rpa')(3)
    return rpa_model.evalDistribution([0.1, 0.1])
[04045f4]826
def test_empty_distribution():
    # type: () -> None
    """
    Check that sasmodels yields NaN when there are no polydispersity points.

    A negative cylinder radius is used to produce the empty distribution,
    and background is zeroed before evaluating at q = 0.1.
    """
    cylinder = _make_standard_model('cylinder')()
    for param_name, param_value in (('radius', -1.0), ('background', 0.)):
        cylinder.setParam(param_name, param_value)
    result = cylinder.evalDistribution(np.asarray([0.1]))
    assert np.isnan(result[0]), "empty distribution fails"
[4d76711]838
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models.

    If building a model fails, the exception is passed through
    annotate_exception with "when loading <name>" and then re-raised. The
    intent is presumably that the report names the failing model.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            # Narrowed from a bare except so KeyboardInterrupt/SystemExit
            # propagate untouched instead of being annotated.
            annotate_exception("when loading "+name)
            raise
851
def test_old_name():
    # type: () -> None
    """
    Load and run cylinder model from sas.models.CylinderModel.

    Does nothing when old-style plugin support is disabled or when the
    ``sas`` package cannot be imported.
    """
    if SUPPORT_OLD_STYLE_PLUGINS:
        try:
            # if sasview is not on the path then don't try to test it
            import sas
        except ImportError:
            pass
        else:
            load_standard_models()
            from sas.models.CylinderModel import CylinderModel
            CylinderModel().evalDistribution([0.1, 0.1])
867
if __name__ == "__main__":
    # Script entry point: run the cylinder smoke test and print its value.
    # The remaining checks can be enabled by uncommenting them.
    cylinder_value = test_cylinder()
    print("cylinder(0.1,0.1)=%g" % cylinder_value)
    #test_product()
    #test_structure_factor()
    #print("rpa:", test_rpa())
    #test_empty_distribution()
Note: See TracBrowser for help on using the repository browser.