source: sasmodels/sasmodels/sasview_model.py @ 2773c66

core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since 2773c66 was 2773c66, checked in by Torin Cooper-Bennun <torin.cooper-bennun@…>, 6 years ago

Merge branch 'beta_approx' into beta_approx_new_R_eff

  • Property mode set to 100644
File size: 32.2 KB
RevLine 
[87985ca]1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
[92d38285]7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
[87985ca]9"""
[4d76711]10from __future__ import print_function
[87985ca]11
[ce27e21]12import math
13from copy import deepcopy
[2622b3f]14import collections
[4d76711]15import traceback
16import logging
[724257c]17from os.path import basename, splitext, abspath, getmtime
[9f8ade1]18try:
19    import _thread as thread
20except ImportError:
21    import thread
[ce27e21]22
[7ae2b7f]23import numpy as np  # type: ignore
[ce27e21]24
[aa4946b]25from . import core
[4d76711]26from . import custom
[a80e64c]27from . import product
[72a081d]28from . import generate
[fb5914f]29from . import weights
[6d6508e]30from . import modelinfo
[bde38b5]31from .details import make_kernel_args, dispersion_mesh
[ff7119b]32
[2d81cfe]33# pylint: disable=unused-import
[fa5fd8d]34try:
[2d81cfe]35    from typing import (Dict, Mapping, Any, Sequence, Tuple, NamedTuple,
36                        List, Optional, Union, Callable)
[fa5fd8d]37    from .modelinfo import ModelInfo, Parameter
38    from .kernel import KernelModel
39    MultiplicityInfoType = NamedTuple(
[a9bc435]40        'MultiplicityInfo',
[fa5fd8d]41        [("number", int), ("control", str), ("choices", List[str]),
42         ("x_axis_label", str)])
[60f03de]43    SasviewModelType = Callable[[int], "SasviewModel"]
[fa5fd8d]44except ImportError:
45    pass
[2d81cfe]46# pylint: enable=unused-import
[fa5fd8d]47
#: module-level logger
logger = logging.getLogger(__name__)

#: lock serializing kernel evaluation; calculate_Iq and
#: calc_composition_models hold it while computing
calculation_lock = thread.allocate_lock()

#: True if pre-existing plugins, with the old names and parameters, should
#: continue to be supported.
SUPPORT_OLD_STYLE_PLUGINS = True

# TODO: separate x_axis_label from multiplicity info
#: Multiplicity description: variant count, name of the control parameter,
#: the list of choices and the x-axis label for the profile plot.
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo', "number control choices x_axis_label")

#: set of defined models (standard and custom), keyed by model name
MODELS = {}  # type: Dict[str, SasviewModelType]
#: custom model {path: model} mapping so we can check timestamps
MODEL_BY_PATH = {}  # type: Dict[str, SasviewModelType]
66
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Return the model class known as *modelname*.

    A name ending in ".py" is treated as a path and loaded with
    load_custom_model; any other name must already be registered in MODELS.

    Raises ValueError for an unregistered, non-path name.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
81
[56b2687]82
# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Load the predefined models into the model registry and return them.

    Every model named by core.list_models() is built and stored in MODELS
    under its name.  If there is an error loading a model, the traceback is
    logged together with the model name, and loading continues with the
    remaining models.  When SUPPORT_OLD_STYLE_PLUGINS is set, the models are
    also placed into sasview under their old names.

    Note: the return value is every entry in MODELS, which also includes any
    models registered before this call (e.g. custom models).
    """
    for name in core.list_models():
        try:
            MODELS[name] = _make_standard_model(name)
        except Exception:
            # One broken model must not prevent the others from loading;
            # name it in the log so the failure can be traced.
            logger.exception("Error loading model %r", name)
    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()

    return list(MODELS.values())
[de97440]101
[4d76711]102
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load a custom model given the model path.

    The model is cached in MODEL_BY_PATH and only reloaded when the
    modification time of *path* changes.  A kernel module that defines a
    *Model* class is treated as an old-style plugin and that class is used
    directly; otherwise a model class is generated from the module's model
    info.

    The model is registered in MODELS.  If a model with a different filename
    is already registered under the same name, the new model is renamed to
    its *id*, or to ``id + '_user'`` if that name is taken as well.
    """
    model = MODEL_BY_PATH.get(path, None)
    if model is not None and model.timestamp == getmtime(path):
        #logger.info("Model already loaded %s", path)
        return model

    #logger.info("Loading model %s", path)
    kernel_module = custom.load_custom_kernel_module(path)
    if hasattr(kernel_module, 'Model'):
        model = kernel_module.Model
        # Old style models do not set the name in the class attributes, so
        # set it here; this name will be overridden when the object is created
        # with an instance variable that has the same value.
        if model.name == "":
            model.name = splitext(basename(path))[0]
        if not hasattr(model, 'filename'):
            # Report the source file rather than the compiled one.  Only the
            # trailing extension is changed: str.replace would also rewrite
            # any '.pyc' appearing earlier in the path.
            filename = abspath(kernel_module.__file__)
            if filename.endswith('.pyc'):
                filename = filename[:-1]
            model.filename = filename
        if not hasattr(model, 'id'):
            model.id = splitext(basename(model.filename))[0]
    else:
        model_info = modelinfo.make_model_info(kernel_module)
        model = make_model_from_info(model_info)
    model.timestamp = getmtime(path)

    # If a model name already exists and we are loading a different model,
    # use the model file name as the model name.
    if model.name in MODELS and not model.filename == MODELS[model.name].filename:
        _previous_name = model.name
        model.name = model.id

        # If the new model name is still in the model list (for instance,
        # if we put a cylinder.py in our plug-in directory), then append
        # an identifier.
        if model.name in MODELS and not model.filename == MODELS[model.name].filename:
            model.name = model.id + '_user'
        logger.info("Model %s already exists: using %s [%s]",
                    _previous_name, model.name, model.filename)

    MODELS[model.name] = model
    MODEL_BY_PATH[path] = model
    return model
[4d76711]148
[87985ca]149
def make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Build a SasviewModel subclass wrapping *model_info*.

    The class is named ``model_info.name``.  Its attributes come from
    _generate_model_attributes, plus the source *filename* and an
    ``__init__`` that forwards *multiplicity* to SasviewModel.
    """
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_dict = _generate_model_attributes(model_info)
    class_dict['__init__'] = __init__
    class_dict['filename'] = model_info.filename
    return type(model_info.name, (SasviewModel,), class_dict)
162
163
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build the sasview model class for the model called *name*.

    The kernel module is located by generate.load_kernel_module (which
    presumably accepts a standard model name; whether paths to custom models
    also work depends on that function), converted to model info, and
    wrapped by make_model_from_info.  The returned class can be used directly
    as a sasview model.
    """
    module = generate.load_kernel_module(name)
    return make_model_from_info(modelinfo.make_model_info(module))
[72a081d]176
177
def _register_old_models():
    # type: () -> None
    """
    Place the new models into sasview under the old names.

    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
    is available to the plugin modules.

    For every entry of the (3, 1, 2) conversion table, a module-like class
    named after the old model is registered both as
    ``sys.modules['sas.models.<OldName>']`` and as the attribute
    ``sas.models.<OldName>``.  It exposes the new model class under the old
    name, so ``from sas.models.<OldName> import <OldName>`` yields the new
    model.
    """
    import sys
    import sas   # needed in order to set sas.models
    import sas.sascalc.fit
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit
    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE

    for new_name, conversion in CONVERSION_TABLE.get((3, 1, 2), {}).items():
        # CoreShellEllipsoidModel => core_shell_ellipsoid:1
        new_name = new_name.split(':')[0]
        # NOTE(review): assumes conversion[0] is the old model name, with
        # conversion[2] overriding it when present -- confirm against
        # conversion_table.
        old_name = conversion[0] if len(conversion) < 3 else conversion[2]
        module_attrs = {old_name: find_model(new_name)}
        ConstructedModule = type(old_name, (), module_attrs)
        old_path = 'sas.models.' + old_name
        # Attach under the plain name: an attribute literally named
        # 'sas.models.<OldName>' cannot be reached by attribute access or by
        # "from sas.models import <OldName>".
        setattr(sas.models, old_name, ConstructedModule)
        sys.modules[old_path] = ConstructedModule
203
204
def MultiplicationModel(form_factor, structure_factor):
    # type: ("SasviewModel", "SasviewModel") -> "SasviewModel"
    """
    Return an instance of the product model built from *form_factor* and
    *structure_factor*.

    The product model class is generated from the model info of both inputs,
    and the instance takes the multiplicity of *form_factor*.
    """
    product_info = product.make_product_info(
        form_factor._model_info, structure_factor._model_info)
    product_class = make_model_from_info(product_info)
    return product_class(form_factor.multiplicity)
[a80e64c]214
[ce27e21]215
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Build the class attribute dictionary for the model wrapping *model_info*.

    The attributes hold everything needed to query the model without
    creating an instance: identification (name, id, description, category),
    the structure/form factor and multiplicity flags, the multiplicity
    description, and the parameter groupings used by sasview (orientation,
    magnetic, fixed, non-fittable, and the drop-down choice list).  The
    parameter groupings are stored as tuples so they cannot be modified by
    accident.
    """
    # TODO: allow model to override axis labels input/output name/unit

    # Multiplicity: the first kernel parameter named model_info.control sets
    # the number of variants and is never fitted.
    xlabel = model_info.profile_axes[0] if model_info.profile is not None else ""
    control = next((par for par in model_info.parameters.kernel_parameters
                    if par.name == model_info.control), None)
    if control is None:
        non_fittable = []  # type: List[str]
        variants = MultiplicityInfo(0, "", [], xlabel)
    else:
        non_fittable = [control.name]
        count = len(control.choices) if control.choices else int(control.limits[1])
        variants = MultiplicityInfo(count, control.name, control.choices, xlabel)

    # Only a single drop-down list parameter is available: take the first
    # parameter with choices; its vector elements are not fittable.
    fun_list = []
    chooser = next((par for par in model_info.parameters.kernel_parameters
                    if par.choices), None)
    if chooser is not None:
        fun_list = chooser.choices
        if chooser.length > 1:
            non_fittable.extend(chooser.id + str(k)
                                for k in range(1, chooser.length + 1))

    # Group the 2D-only parameters.
    orientation_params = []
    magnetic_params = []
    fixed = []
    for par in model_info.parameters.user_parameters({}, is2d=True):
        width_name = par.name + ".width"
        if par.type == 'orientation':
            orientation_params.extend([par.name, width_name])
            fixed.append(width_name)
        elif par.type == 'magnetic':
            orientation_params.append(par.name)
            magnetic_params.append(par.name)
            fixed.append(width_name)

    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.effective_radius_type is not None,
        'is_multiplicity_model': variants[0] > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }
[4d76711]284
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.
    """
    # Model-specific values are supplied as class attributes of the
    # subclasses built by make_model_from_info / _generate_model_attributes.
    # The defaults below exist for typing and documentation purposes.

    #: compiled kernel model, built on first use
    _model = None       # type: KernelModel
    #: model description the wrapper is built from
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None  # type: List[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None     # type: List[str]
    #: ".width" names of the orientation and magnetic parameters, as built by
    #: _generate_model_attributes; is_fittable reports membership in this list
    # TODO: the attribute fixed is ill-named
    fixed = None               # type: List[str]

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity, passed to the kernel call
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()  # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None  # type: MultiplicityInfoType

    # Per-instance variables (assigned in __init__)
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None  # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None  # type: Dict[str, Tuple[np.ndarray, np.ndarray]]
[fa5fd8d]346
347    def __init__(self, multiplicity=None):
[04dc697]348        # type: (Optional[int]) -> None
[2622b3f]349
[04045f4]350        # TODO: _persistency_dict to persistency_dict throughout sasview
351        # TODO: refactor multiplicity to encompass variants
352        # TODO: dispersion should be a class
[fa5fd8d]353        # TODO: refactor multiplicity info
354        # TODO: separate profile view from multiplicity
355        # The button label, x and y axis labels and scale need to be under
356        # the control of the model, not the fit page.  Maximum flexibility,
357        # the fit page would supply the canvas and the profile could plot
358        # how it wants, but this assumes matplotlib.  Next level is that
359        # we provide some sort of data description including title, labels
360        # and lines to plot.
361
[1f35235]362        # Get the list of hidden parameters given the multiplicity
[04045f4]363        # Don't include multiplicity in the list of parameters
[fa5fd8d]364        self.multiplicity = multiplicity
[04045f4]365        if multiplicity is not None:
366            hidden = self._model_info.get_hidden_parameters(multiplicity)
367            hidden |= set([self.multiplicity_info.control])
368        else:
369            hidden = set()
[8f93522]370        if self._model_info.structure_factor:
371            hidden.add('scale')
372            hidden.add('background')
373            self._model_info.parameters.defaults['background'] = 0.
[04045f4]374
[04dc697]375        self._persistency_dict = {}
[fa5fd8d]376        self.params = collections.OrderedDict()
[b3a85cd]377        self.dispersion = collections.OrderedDict()
[fa5fd8d]378        self.details = {}
[8977226]379        for p in self._model_info.parameters.user_parameters({}, is2d=True):
[04045f4]380            if p.name in hidden:
[fa5fd8d]381                continue
[fcd7bbd]382            self.params[p.name] = p.default
[aa44a6a]383            if p.limits and type(p.limits) is list and len(p.limits) > 1:
384                self.details[p.id] = [p.units if not p.choices else p.choices, p.limits[0], p.limits[1]]
385            else:
386                self.details[p.id] = [p.units if not p.choices else p.choices, None, None]
[fb5914f]387            if p.polydisperse:
[fa5fd8d]388                self.details[p.id+".width"] = [
389                    "", 0.0, 1.0 if p.relative_pd else np.inf
390                ]
[fb5914f]391                self.dispersion[p.name] = {
392                    'width': 0,
393                    'npts': 35,
394                    'nsigmas': 3,
395                    'type': 'gaussian',
396                }
[ce27e21]397
[de97440]398    def __get_state__(self):
[fa5fd8d]399        # type: () -> Dict[str, Any]
[de97440]400        state = self.__dict__.copy()
[4d76711]401        state.pop('_model')
[de97440]402        # May need to reload model info on set state since it has pointers
403        # to python implementations of Iq, etc.
404        #state.pop('_model_info')
405        return state
406
407    def __set_state__(self, state):
[fa5fd8d]408        # type: (Dict[str, Any]) -> None
[de97440]409        self.__dict__ = state
[fb5914f]410        self._model = None
[de97440]411
[ce27e21]412    def __str__(self):
[fa5fd8d]413        # type: () -> str
[ce27e21]414        """
415        :return: string representation
416        """
417        return self.name
418
419    def is_fittable(self, par_name):
[fa5fd8d]420        # type: (str) -> bool
[ce27e21]421        """
422        Check if a given parameter is fittable or not
423
424        :param par_name: the parameter name to check
425        """
[e758662]426        return par_name in self.fixed
[ce27e21]427        #For the future
428        #return self.params[str(par_name)].is_fittable()
429
430
431    def getProfile(self):
[fa5fd8d]432        # type: () -> (np.ndarray, np.ndarray)
[ce27e21]433        """
434        Get SLD profile
435
436        : return: (z, beta) where z is a list of depth of the transition points
437                beta is a list of the corresponding SLD values
438        """
[745b7bb]439        args = {} # type: Dict[str, Any]
[fa5fd8d]440        for p in self._model_info.parameters.kernel_parameters:
441            if p.id == self.multiplicity_info.control:
[745b7bb]442                value = float(self.multiplicity)
[fa5fd8d]443            elif p.length == 1:
[745b7bb]444                value = self.params.get(p.id, np.NaN)
[fa5fd8d]445            else:
[745b7bb]446                value = np.array([self.params.get(p.id+str(k), np.NaN)
[b32dafd]447                                  for k in range(1, p.length+1)])
[745b7bb]448            args[p.id] = value
449
[e7fe459]450        x, y = self._model_info.profile(**args)
451        return x, 1e-6*y
[ce27e21]452
453    def setParam(self, name, value):
[fa5fd8d]454        # type: (str, float) -> None
[ce27e21]455        """
456        Set the value of a model parameter
457
458        :param name: name of the parameter
459        :param value: value of the parameter
460
461        """
462        # Look for dispersion parameters
463        toks = name.split('.')
[de0c4ba]464        if len(toks) == 2:
[ce27e21]465            for item in self.dispersion.keys():
[e758662]466                if item == toks[0]:
[ce27e21]467                    for par in self.dispersion[item]:
[e758662]468                        if par == toks[1]:
[ce27e21]469                            self.dispersion[item][par] = value
470                            return
471        else:
472            # Look for standard parameter
473            for item in self.params.keys():
[e758662]474                if item == name:
[ce27e21]475                    self.params[item] = value
476                    return
477
[63b32bb]478        raise ValueError("Model does not contain parameter %s" % name)
[ce27e21]479
480    def getParam(self, name):
[fa5fd8d]481        # type: (str) -> float
[ce27e21]482        """
483        Set the value of a model parameter
484
485        :param name: name of the parameter
486
487        """
488        # Look for dispersion parameters
489        toks = name.split('.')
[de0c4ba]490        if len(toks) == 2:
[ce27e21]491            for item in self.dispersion.keys():
[e758662]492                if item == toks[0]:
[ce27e21]493                    for par in self.dispersion[item]:
[e758662]494                        if par == toks[1]:
[ce27e21]495                            return self.dispersion[item][par]
496        else:
497            # Look for standard parameter
498            for item in self.params.keys():
[e758662]499                if item == name:
[ce27e21]500                    return self.params[item]
501
[63b32bb]502        raise ValueError("Model does not contain parameter %s" % name)
[ce27e21]503
504    def getParamList(self):
[04dc697]505        # type: () -> Sequence[str]
[ce27e21]506        """
507        Return a list of all available parameters for the model
508        """
[04dc697]509        param_list = list(self.params.keys())
[ce27e21]510        # WARNING: Extending the list with the dispersion parameters
[de0c4ba]511        param_list.extend(self.getDispParamList())
512        return param_list
[ce27e21]513
514    def getDispParamList(self):
[04dc697]515        # type: () -> Sequence[str]
[ce27e21]516        """
[fb5914f]517        Return a list of polydispersity parameters for the model
[ce27e21]518        """
[1780d59]519        # TODO: fix test so that parameter order doesn't matter
[3bcb88c]520        ret = ['%s.%s' % (p_name, ext)
521               for p_name in self.dispersion.keys()
522               for ext in ('npts', 'nsigmas', 'width')]
[9404dd3]523        #print(ret)
[1780d59]524        return ret
[ce27e21]525
526    def clone(self):
[04dc697]527        # type: () -> "SasviewModel"
[ce27e21]528        """ Return a identical copy of self """
529        return deepcopy(self)
530
531    def run(self, x=0.0):
[fa5fd8d]532        # type: (Union[float, (float, float), List[float]]) -> float
[ce27e21]533        """
534        Evaluate the model
535
536        :param x: input q, or [q,phi]
537
538        :return: scattering function P(q)
539
540        **DEPRECATED**: use calculate_Iq instead
541        """
[de0c4ba]542        if isinstance(x, (list, tuple)):
[3c56da87]543            # pylint: disable=unpacking-non-sequence
[ce27e21]544            q, phi = x
[84f2962]545            result, _ = self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])
546            return result[0]
[ce27e21]547        else:
[84f2962]548            result, _ = self.calculate_Iq([x])
549            return result[0]
[ce27e21]550
551
552    def runXY(self, x=0.0):
[fa5fd8d]553        # type: (Union[float, (float, float), List[float]]) -> float
[ce27e21]554        """
555        Evaluate the model in cartesian coordinates
556
557        :param x: input q, or [qx, qy]
558
559        :return: scattering function P(q)
560
561        **DEPRECATED**: use calculate_Iq instead
562        """
[de0c4ba]563        if isinstance(x, (list, tuple)):
[84f2962]564            result, _ = self.calculate_Iq([x[0]], [x[1]])
565            return result[0]
[ce27e21]566        else:
[84f2962]567            result, _ = self.calculate_Iq([x])
568            return result[0]
[ce27e21]569
570    def evalDistribution(self, qdist):
[04dc697]571        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
[d138d43]572        r"""
[ce27e21]573        Evaluate a distribution of q-values.
574
[d138d43]575        :param qdist: array of q or a list of arrays [qx,qy]
[ce27e21]576
[d138d43]577        * For 1D, a numpy array is expected as input
[ce27e21]578
[d138d43]579        ::
[ce27e21]580
[d138d43]581            evalDistribution(q)
[ce27e21]582
[d138d43]583          where *q* is a numpy array.
[ce27e21]584
[d138d43]585        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input
[ce27e21]586
[d138d43]587        ::
[ce27e21]588
[d138d43]589              qx = [ qx[0], qx[1], qx[2], ....]
590              qy = [ qy[0], qy[1], qy[2], ....]
[ce27e21]591
[d138d43]592        If the model is 1D only, then
[ce27e21]593
[d138d43]594        .. math::
[ce27e21]595
[d138d43]596            q = \sqrt{q_x^2+q_y^2}
[ce27e21]597
598        """
[de0c4ba]599        if isinstance(qdist, (list, tuple)):
[ce27e21]600            # Check whether we have a list of ndarrays [qx,qy]
601            qx, qy = qdist
[84f2962]602            result, _ = self.calculate_Iq(qx, qy)
603            return result
[ce27e21]604
605        elif isinstance(qdist, np.ndarray):
606            # We have a simple 1D distribution of q-values
[84f2962]607            result, _ = self.calculate_Iq(qdist)
608            return result
[ce27e21]609
610        else:
[3c56da87]611            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
612                            % type(qdist))
[ce27e21]613
[9dcb21d]614    def calc_composition_models(self, qx):
[64614ad]615        """
[9dcb21d]616        returns parts of the composition model or None if not a composition
617        model.
[64614ad]618        """
[946c8d27]619        # TODO: have calculate_Iq return the intermediates.
620        #
621        # The current interface causes calculate_Iq() to be called twice,
622        # once to get the combined result and again to get the intermediate
623        # results.  This is necessary for now.
624        # Long term, the solution is to change the interface to calculate_Iq
625        # so that it returns a results object containing all the bits:
[9644b5a]626        #     the A, B, C, ... of the composition model (and any subcomponents?)
[d32de68]627        #     the P and S of the product model
[946c8d27]628        #     the combined model before resolution smearing,
629        #     the sasmodel before sesans conversion,
630        #     the oriented 2D model used to fit oriented usans data,
631        #     the final I(q),
632        #     ...
[9644b5a]633        #
[946c8d27]634        # Have the model calculator add all of these blindly to the data
635        # tree, and update the graphs which contain them.  The fitter
636        # needs to be updated to use the I(q) value only, ignoring the rest.
637        #
638        # The simple fix of returning the existing intermediate results
639        # will not work for a couple of reasons: (1) another thread may
640        # sneak in to compute its own results before calc_composition_models
641        # is called, and (2) calculate_Iq is currently called three times:
642        # once with q, once with q values before qmin and once with q values
643        # after q max.  Both of these should be addressed before
644        # replacing this code.
[9644b5a]645        composition = self._model_info.composition
646        if composition and composition[0] == 'product': # only P*S for now
647            with calculation_lock:
[84f2962]648                _, lazy_results = self._calculate_Iq(qx)
649                # for compatibility with sasview 4.x
650                results = lazy_results()
651                pq_data = results.get("P(Q)")
652                sq_data = results.get("S(Q)")
653                return pq_data, sq_data
[9644b5a]654        else:
655            return None
[bf8c271]656
[84f2962]657    def calculate_Iq(self,
658                     qx,     # type: Sequence[float]
659                     qy=None # type: Optional[Sequence[float]]
660                     ):
661        # type: (...) -> Tuple[np.ndarray, Callable[[], collections.OrderedDict[str, np.ndarray]]]
[ff7119b]662        """
663        Calculate Iq for one set of q with the current parameters.
664
665        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.
666
667        This should NOT be used for fitting since it copies the *q* vectors
668        to the card for each evaluation.
[84f2962]669
670        The returned tuple contains the scattering intensity followed by a
671        callable which returns a dictionary of intermediate data from
672        ProductKernel.
[ff7119b]673        """
[a38b065]674        ## uncomment the following when trying to debug the uncoordinated calls
675        ## to calculate_Iq
676        #if calculation_lock.locked():
[724257c]677        #    logger.info("calculation waiting for another thread to complete")
678        #    logger.info("\n".join(traceback.format_stack()))
[a38b065]679
680        with calculation_lock:
681            return self._calculate_Iq(qx, qy)
682
683    def _calculate_Iq(self, qx, qy=None):
[fb5914f]684        if self._model is None:
[d2bb604]685            self._model = core.build_model(self._model_info)
[fa5fd8d]686        if qy is not None:
687            q_vectors = [np.asarray(qx), np.asarray(qy)]
688        else:
689            q_vectors = [np.asarray(qx)]
[a738209]690        calculator = self._model.make_kernel(q_vectors)
[6a0d6aa]691        parameters = self._model_info.parameters
692        pairs = [self._get_weights(p) for p in parameters.call_parameters]
[9c1a59c]693        #weights.plot_weights(self._model_info, pairs)
[bde38b5]694        call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
[4edec6f]695        #call_details.show()
[05df1de]696        #print("================ parameters ==================")
697        #for p, v in zip(parameters.call_parameters, pairs): print(p.name, v[0])
[ce99754]698        #for k, p in enumerate(self._model_info.parameters.call_parameters):
699        #    print(k, p.name, *pairs[k])
[4edec6f]700        #print("params", self.params)
701        #print("values", values)
702        #print("is_mag", is_magnetic)
[6a0d6aa]703        result = calculator(call_details, values, cutoff=self.cutoff,
[9eb3632]704                            magnetic=is_magnetic)
[84f2962]705        lazy_results = getattr(calculator, 'results',
706                               lambda: collections.OrderedDict())
[ce99754]707        #print("result", result)
[84f2962]708
[a738209]709        calculator.release()
[d533590]710        #self._model.release()
[84f2962]711
712        return result, lazy_results
[ce27e21]713
714    def calculate_ER(self):
[fa5fd8d]715        # type: () -> float
[ce27e21]716        """
717        Calculate the effective radius for P(q)*S(q)
718
719        :return: the value of the effective radius
720        """
[4bfd277]721        if self._model_info.ER is None:
[ce27e21]722            return 1.0
723        else:
[4bfd277]724            value, weight = self._dispersion_mesh()
725            fv = self._model_info.ER(*value)
[9404dd3]726            #print(values[0].shape, weights.shape, fv.shape)
[4bfd277]727            return np.sum(weight * fv) / np.sum(weight)
[ce27e21]728
729    def calculate_VR(self):
[fa5fd8d]730        # type: () -> float
[ce27e21]731        """
732        Calculate the volf ratio for P(q)*S(q)
733
734        :return: the value of the volf ratio
735        """
[4bfd277]736        if self._model_info.VR is None:
[ce27e21]737            return 1.0
738        else:
[4bfd277]739            value, weight = self._dispersion_mesh()
740            whole, part = self._model_info.VR(*value)
741            return np.sum(weight * part) / np.sum(weight * whole)
[ce27e21]742
743    def set_dispersion(self, parameter, dispersion):
[7c3fb15]744        # type: (str, weights.Dispersion) -> None
[ce27e21]745        """
746        Set the dispersion object for a model parameter
747
748        :param parameter: name of the parameter [string]
749        :param dispersion: dispersion object of type Dispersion
750        """
[fa800e72]751        if parameter in self.params:
[1780d59]752            # TODO: Store the disperser object directly in the model.
[56b2687]753            # The current method of relying on the sasview GUI to
[fa800e72]754            # remember them is kind of funky.
[1780d59]755            # Note: can't seem to get disperser parameters from sasview
[9c1a59c]756            # (1) Could create a sasview model that has not yet been
[1780d59]757            # converted, assign the disperser to one of its polydisperse
758            # parameters, then retrieve the disperser parameters from the
[9c1a59c]759            # sasview model.
760            # (2) Could write a disperser parameter retriever in sasview.
761            # (3) Could modify sasview to use sasmodels.weights dispersers.
[1780d59]762            # For now, rely on the fact that the sasview only ever uses
763            # new dispersers in the set_dispersion call and create a new
764            # one instead of trying to assign parameters.
[ce27e21]765            self.dispersion[parameter] = dispersion.get_pars()
766        else:
[7c3fb15]767            raise ValueError("%r is not a dispersity or orientation parameter"
768                             % parameter)
[ce27e21]769
[aa4946b]770    def _dispersion_mesh(self):
[fa5fd8d]771        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
[ce27e21]772        """
773        Create a mesh grid of dispersion parameters and weights.
774
775        Returns [p1,p2,...],w where pj is a vector of values for parameter j
776        and w is a vector containing the products for weights for each
777        parameter set in the vector.
778        """
[4bfd277]779        pars = [self._get_weights(p)
780                for p in self._model_info.parameters.call_parameters
781                if p.type == 'volume']
[9eb3632]782        return dispersion_mesh(self._model_info, pars)
[ce27e21]783
784    def _get_weights(self, par):
[fa5fd8d]785        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
[de0c4ba]786        """
[fb5914f]787        Return dispersion weights for parameter
[de0c4ba]788        """
[fa5fd8d]789        if par.name not in self.params:
790            if par.name == self.multiplicity_info.control:
[32f87a5]791                return self.multiplicity, [self.multiplicity], [1.0]
[fa5fd8d]792            else:
[17db833]793                # For hidden parameters use default values.  This sets
794                # scale=1 and background=0 for structure factors
795                default = self._model_info.parameters.defaults.get(par.name, np.NaN)
796                return default, [default], [1.0]
[fa5fd8d]797        elif par.polydisperse:
[32f87a5]798            value = self.params[par.name]
[fb5914f]799            dis = self.dispersion[par.name]
[9c1a59c]800            if dis['type'] == 'array':
[32f87a5]801                dispersity, weight = dis['values'], dis['weights']
[9c1a59c]802            else:
[32f87a5]803                dispersity, weight = weights.get_weights(
[9c1a59c]804                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
[32f87a5]805                    value, par.limits, par.relative_pd)
806            return value, dispersity, weight
[fb5914f]807        else:
[32f87a5]808            value = self.params[par.name]
[ce99754]809            return value, [value], [1.0]
[ce27e21]810
def test_cylinder():
    # type: () -> float
    """
    Check that the standard cylinder model can be built and evaluated.

    The model is evaluated at [0.1, 0.1], and the resulting intensity is
    returned.
    """
    cylinder_model = _make_standard_model('cylinder')()
    return cylinder_model.evalDistribution([0.1, 0.1])
[de97440]819
def test_structure_factor():
    # type: () -> None
    """
    Test that the hardsphere model runs in 1-D and 2-D and doesn't produce NaN.

    The 1-D point |q| = 0.1*sqrt(2) is the magnitude of the 2-D point
    [0.1, 0.1], so both evaluations probe the same |q|.

    :raises ValueError: if either evaluation contains NaN
    """
    Model = _make_standard_model('hardsphere')
    model = Model()
    value2d = model.evalDistribution([0.1, 0.1])
    value1d = model.evalDistribution(np.array([0.1*np.sqrt(2)]))
    #print("hardsphere", value1d, value2d)
    # np.any() keeps the check valid for array results of any length.  A bare
    # "np.isnan(a) or ..." relies on array truthiness, which is ambiguous for
    # more than one element.
    if np.any(np.isnan(value1d)) or np.any(np.isnan(value2d)):
        raise ValueError("hardsphere returns nan")
[8f93522]832
def test_product():
    # type: () -> None
    """
    Test that the 2-D cylinder*hayter_msa product model runs and doesn't
    produce NaN.

    :raises ValueError: if the product model evaluates to NaN
    """
    S = _make_standard_model('hayter_msa')()
    P = _make_standard_model('cylinder')()
    model = MultiplicationModel(P, S)
    value = model.evalDistribution([0.1, 0.1])
    # np.any() avoids relying on the truth value of an array result.
    if np.any(np.isnan(value)):
        raise ValueError("cylinder*hayter_msa returns nan")
844
def test_rpa():
    # type: () -> float
    """
    Test that the 2-D RPA model runs.

    The model is constructed as RPA(3) and evaluated at [0.1, 0.1], and the
    result is returned.
    """
    rpa_model = _make_standard_model('rpa')(3)
    return rpa_model.evalDistribution([0.1, 0.1])
[04045f4]853
def test_empty_distribution():
    # type: () -> None
    """
    Make sure that sasmodels returns 0 when there are no polydispersity points.

    The radius is set to -1.0, which is intended to leave the distribution
    empty.  The background is set to 0, so the returned intensity must be
    exactly 0.
    """
    Cylinder = _make_standard_model('cylinder')
    cylinder = Cylinder()
    cylinder.setParam('radius', -1.0)
    cylinder.setParam('background', 0.)
    Iq = cylinder.evalDistribution(np.asarray([0.1]))
    assert Iq[0] == 0., "empty distribution fails"
[4d76711]865
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models.

    If a model fails to build, the error is annotated with the model's name
    and then re-raised.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            # Annotate genuine errors only.  KeyboardInterrupt and SystemExit
            # propagate untouched instead of being caught by a bare except.
            annotate_exception("when loading "+name)
            raise
878
def test_old_name():
    # type: () -> None
    """
    Load and run the cylinder model through its old-style name,
    sas.models.CylinderModel.

    This test does nothing when old-style plugin support is disabled, or
    when the sasview package ``sas`` cannot be imported.
    """
    if not SUPPORT_OLD_STYLE_PLUGINS:
        return None

    # Only probe for sasview; if it is not on the path, skip the test.
    try:
        import sas  # pylint: disable=unused-import
        have_sasview = True
    except ImportError:
        have_sasview = False

    if have_sasview:
        load_standard_models()
        from sas.models.CylinderModel import CylinderModel
        CylinderModel().evalDistribution([0.1, 0.1])
    return None
894
def magnetic_demo():
    """
    Display the 2-D scattering pattern of the sphere model with 'M0:sld' = 8.

    The model is evaluated on a 500 x 500 (qx, qy) grid spanning
    [-0.35, 0.35] along each axis.  log(I + 0.001) is shown as an image
    with pylab.
    """
    sphere = _make_standard_model('sphere')()
    sphere.setParam('M0:sld', 8)
    q_axis = np.linspace(-0.35, 0.35, 500)
    qx, qy = np.meshgrid(q_axis, q_axis)
    intensity, _ = sphere.calculate_Iq(qx.flatten(), qy.flatten())
    image = intensity.reshape(qx.shape)

    # pylab is imported lazily so the module does not require it.
    import pylab
    pylab.imshow(np.log(image + 0.001))
    pylab.show()
907
if __name__ == "__main__":
    # Script entry point: run a quick cylinder smoke test.  The other checks
    # in this module (magnetic_demo, test_product, test_structure_factor,
    # test_rpa, test_empty_distribution) can be called here as needed.
    cylinder_value = test_cylinder()
    print("cylinder(0.1,0.1)=%g" % cylinder_value)
Note: See TracBrowser for help on using the repository browser.