source: sasmodels/sasmodels/sasview_model.py @ d321747

core_shell_microgels, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since d321747 was d321747, checked in by Paul Kienzle <pkienzle@…>, 5 years ago

sasview only updates models when code changes; now detects changes to c files

  • Property mode set to 100644
File size: 31.8 KB
RevLine 
[87985ca]1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
[92d38285]7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
[87985ca]9"""
[4d76711]10from __future__ import print_function
[87985ca]11
[ce27e21]12import math
13from copy import deepcopy
[2622b3f]14import collections
[4d76711]15import traceback
16import logging
[724257c]17from os.path import basename, splitext, abspath, getmtime
[9f8ade1]18try:
19    import _thread as thread
20except ImportError:
21    import thread
[ce27e21]22
[7ae2b7f]23import numpy as np  # type: ignore
[ce27e21]24
[aa4946b]25from . import core
[4d76711]26from . import custom
[a80e64c]27from . import product
[72a081d]28from . import generate
[fb5914f]29from . import weights
[6d6508e]30from . import modelinfo
[bde38b5]31from .details import make_kernel_args, dispersion_mesh
[ff7119b]32
[2d81cfe]33# pylint: disable=unused-import
[fa5fd8d]34try:
[2d81cfe]35    from typing import (Dict, Mapping, Any, Sequence, Tuple, NamedTuple,
36                        List, Optional, Union, Callable)
[fa5fd8d]37    from .modelinfo import ModelInfo, Parameter
38    from .kernel import KernelModel
39    MultiplicityInfoType = NamedTuple(
[a9bc435]40        'MultiplicityInfo',
[fa5fd8d]41        [("number", int), ("control", str), ("choices", List[str]),
42         ("x_axis_label", str)])
[60f03de]43    SasviewModelType = Callable[[int], "SasviewModel"]
[fa5fd8d]44except ImportError:
45    pass
[2d81cfe]46# pylint: enable=unused-import
[fa5fd8d]47
# Module-wide logger.
logger = logging.getLogger(__name__)

# Serializes kernel evaluations; calculate_Iq and calc_composition_models
# both acquire it before running the kernel.
calculation_lock = thread.allocate_lock()

#: True if pre-existing plugins, with the old names and parameters, should
#: continue to be supported.
SUPPORT_OLD_STYLE_PLUGINS = True

# TODO: separate x_axis_label from multiplicity info
#: Description of a model's multiplicity (variant) control parameter.
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo', ('number', 'control', 'choices', 'x_axis_label'))

#: mapping of defined models (standard and custom) by model name
MODELS = {}  # type: Dict[str, SasviewModelType]
# TODO: remove unused MODEL_BY_PATH cache once sasview no longer references it
#: custom model {path: model} mapping so we can check timestamps
MODEL_BY_PATH = {}  # type: Dict[str, SasviewModelType]
#: Kernel module last seen for each custom model path; load_custom_model
#: compares against it to decide whether the registered model is stale.
_CACHED_MODULE = {}  # type: Dict[str, "module"]
[724257c]70
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Look up a model class by name.

    A *modelname* ending in ``py`` is treated as a custom model path and
    handed to load_custom_model; any other name must already be registered
    in MODELS.

    Raises ValueError when the name is not a registered model.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
85
[56b2687]86
[fa5fd8d]87# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Register every built-in model and return the registered model classes.

    Each name from core.list_models() is converted into a model class and
    stored in MODELS under that name.  A model that fails to build is
    skipped, and its traceback is logged at error level.  When
    SUPPORT_OLD_STYLE_PLUGINS is set, the legacy sas.models names are
    installed as well.

    Returns all values currently in MODELS (which may also contain models
    registered by earlier calls, e.g. custom models).
    """
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
        except Exception:
            logger.error(traceback.format_exc())
        else:
            MODELS[model_name] = model_class

    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()

    return list(MODELS.values())
[de97440]105
[4d76711]106
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load a custom model given the model path.

    The kernel module at *path* is loaded (possibly from the loader's cache)
    and turned into a model class.  An old-style plugin that defines a
    ``Model`` class is used as-is after filling in ``name``, ``filename``
    and ``id`` where they are missing.  A new-style module is converted with
    make_model_from_info.

    If a model with the same name but a different file is already in
    MODELS, the new model is renamed to its id.  If that name is also taken
    by a different file, ``_user`` is appended.

    The MODELS entry is only replaced when the module object differs from
    the one seen on the previous call for this path, or when the name is
    not registered yet.  Otherwise the previously registered class is
    returned.

    Returns the model class registered in MODELS under the final name.
    """
    #logger.info("Loading model %s", path)

    # Load the kernel module.  This may already be cached by the loader, so
    # only requires checking the timestamps of the dependents.
    kernel_module = custom.load_custom_kernel_module(path)

    # Check if the module has changed since we last looked.  Modules have no
    # value equality, so this is an identity test.
    reloaded = kernel_module is not _CACHED_MODULE.get(path, None)
    _CACHED_MODULE[path] = kernel_module

    # Turn the module into a model.  We need to do this even if the model
    # has already been loaded so that we can determine the model name and
    # retrieve it from the MODELS cache.
    model = getattr(kernel_module, 'Model', None)
    if model is not None:
        # Old style models do not set the name in the class attributes, so
        # set it here; this name will be overridden when the object is created
        # with an instance variable that has the same value.
        if model.name == "":
            model.name = splitext(basename(path))[0]
        if not hasattr(model, 'filename'):
            filename = abspath(kernel_module.__file__)
            # Point at the source rather than the compiled bytecode.  Only a
            # trailing ".pyc" is stripped; ".pyc" appearing elsewhere in the
            # path (e.g. in a directory name) must be left untouched.
            if filename.endswith('.pyc'):
                filename = filename[:-1]
            model.filename = filename
        if not hasattr(model, 'id'):
            model.id = splitext(basename(model.filename))[0]
    else:
        model_info = modelinfo.make_model_info(kernel_module)
        model = make_model_from_info(model_info)

    # If a model name already exists and we are loading a different model,
    # use the model file name as the model name.
    if model.name in MODELS and model.filename != MODELS[model.name].filename:
        _previous_name = model.name
        model.name = model.id

        # If the new model name is still in the model list (for instance,
        # if we put a cylinder.py in our plug-in directory), then append
        # an identifier.
        if model.name in MODELS and model.filename != MODELS[model.name].filename:
            model.name = model.id + '_user'
        logger.info("Model %s already exists: using %s [%s]",
                    _previous_name, model.name, model.filename)

    # Only update the model if the module has changed
    if reloaded or model.name not in MODELS:
        MODELS[model.name] = model

    return MODELS[model.name]
149        if model.name in MODELS and not model.filename == MODELS[model.name].filename:
150            model.name = model.id + '_user'
[724257c]151        logger.info("Model %s already exists: using %s [%s]",
152                    _previous_name, model.name, model.filename)
[ed10b57]153
[d321747]154    # Only update the model if the module has changed
155    if reloaded or model.name not in MODELS:
156        MODELS[model.name] = model
157
158    return MODELS[model.name]
[4d76711]159
[87985ca]160
def make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Build a SasviewModel subclass that wraps *model_info*.

    The class is named ``model_info.name``.  Its attributes come from
    _generate_model_attributes plus ``filename``, and its constructor
    forwards *multiplicity* to SasviewModel.__init__.
    """
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs = _generate_model_attributes(model_info)
    class_attrs.update({'__init__': __init__,
                        'filename': model_info.filename})
    return type(model_info.name, (SasviewModel,), class_attrs)
173
174
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build the sasview model class for model *name*.

    The kernel module is located with generate.load_kernel_module, its model
    info is extracted with modelinfo.make_model_info, and the result is
    wrapped by make_model_from_info.

    Returns a class that can be used directly as a sasview model.
    """
    return make_model_from_info(
        modelinfo.make_model_info(generate.load_kernel_module(name)))
[72a081d]187
188
def _register_old_models():
    # type: () -> None
    """
    Place the new models into sasview under the old names.

    Monkey patch sas.sascalc.fit as sas.models so that sas.models.pluginmodel
    is available to the plugin modules.

    For every entry of the (3, 1, 2) conversion table, a module-like class
    named after the old model is built.  It holds the new model class under
    the old name.  This class is registered in sys.modules as
    ``sas.models.<old_name>`` and set as attribute ``<old_name>`` of
    sas.models.
    """
    import sys
    import sas   # needed in order to set sas.models
    import sas.sascalc.fit
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit
    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE

    for new_name, conversion in CONVERSION_TABLE.get((3, 1, 2), {}).items():
        # CoreShellEllipsoidModel => core_shell_ellipsoid:1
        new_name = new_name.split(':')[0]
        old_name = conversion[0] if len(conversion) < 3 else conversion[2]
        module_attrs = {old_name: find_model(new_name)}
        ConstructedModule = type(old_name, (), module_attrs)
        old_path = 'sas.models.' + old_name
        # Bind under the bare old name: an attribute literally named
        # "sas.models.<old_name>" is unreachable via sas.models.<old_name>.
        setattr(sas.models, old_name, ConstructedModule)
        sys.modules[old_path] = ConstructedModule
214
215
def MultiplicationModel(form_factor, structure_factor):
    # type: ("SasviewModel", "SasviewModel") -> "SasviewModel"
    """
    Returns a constructed product model from form_factor and structure_factor.

    The product model info is combined from the ``_model_info`` of the two
    model instances and wrapped into a new model class.  That class is then
    instantiated with the multiplicity of *form_factor*.
    """
    product_info = product.make_product_info(
        form_factor._model_info, structure_factor._model_info)
    ProductModel = make_model_from_info(product_info)
    return ProductModel(form_factor.multiplicity)
[a80e64c]225
[ce27e21]226
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Generate the class attributes for the model.

    This should include all the information necessary to query the model
    details so that you do not need to instantiate a model to query it.

    All the attributes should be immutable to avoid accidents.

    Returns a dict holding:
    - the model identity (name, id, description, category);
    - the structure factor / form factor / multiplicity flags;
    - the multiplicity description;
    - tuples of orientation, magnetic, "fixed" and non-fittable parameter
      names;
    - ``fun_list``, the choices of the first kernel parameter that has
      choices.
    """
    # TODO: allow model to override axis labels input/output name/unit

    kernel_pars = model_info.parameters.kernel_parameters
    xlabel = model_info.profile_axes[0] if model_info.profile is not None else ""
    non_fittable = []  # type: List[str]

    # Multiplicity: the kernel parameter named by model_info.control (if
    # any) selects the model variant and is never fitted.
    variants = MultiplicityInfo(0, "", [], xlabel)
    control_par = next(
        (p for p in kernel_pars if p.name == model_info.control), None)
    if control_par is not None:
        non_fittable.append(control_par.name)
        if control_par.choices:
            count = len(control_par.choices)
        else:
            count = int(control_par.limits[1])
        variants = MultiplicityInfo(count, control_par.name,
                                    control_par.choices, xlabel)

    # Only a single drop-down list parameter is available: the first kernel
    # parameter with choices supplies fun_list.  For a vector-valued one,
    # its elements id1..idN are marked non-fittable.
    fun_list = []
    choice_par = next((p for p in kernel_pars if p.choices), None)
    if choice_par is not None:
        fun_list = choice_par.choices
        if choice_par.length > 1:
            non_fittable.extend(choice_par.id + str(k)
                                for k in range(1, choice_par.length + 1))

    # Organize parameter sets
    orientation_params = []
    magnetic_params = []
    fixed = []
    for par in model_info.parameters.user_parameters({}, is2d=True):
        if par.type == 'orientation':
            orientation_params.extend((par.name, par.name + ".width"))
            fixed.append(par.name + ".width")
        elif par.type == 'magnetic':
            # magnetic parameters are listed among the orientation ones too
            orientation_params.append(par.name)
            magnetic_params.append(par.name)
            fixed.append(par.name + ".width")

    # Build class dictionary
    return {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }
[4d76711]295
[ce27e21]296class SasviewModel(object):
297    """
298    Sasview wrapper for opencl/ctypes model.
299    """
[fa5fd8d]300    # Model parameters for the specific model are set in the class constructor
301    # via the _generate_model_attributes function, which subclasses
302    # SasviewModel.  They are included here for typing and documentation
303    # purposes.
304    _model = None       # type: KernelModel
305    _model_info = None  # type: ModelInfo
306    #: load/save name for the model
307    id = None           # type: str
308    #: display name for the model
309    name = None         # type: str
310    #: short model description
311    description = None  # type: str
312    #: default model category
313    category = None     # type: str
314
315    #: names of the orientation parameters in the order they appear
[724257c]316    orientation_params = None # type: List[str]
[fa5fd8d]317    #: names of the magnetic parameters in the order they appear
[724257c]318    magnetic_params = None    # type: List[str]
[fa5fd8d]319    #: names of the fittable parameters
[724257c]320    fixed = None              # type: List[str]
[fa5fd8d]321    # TODO: the attribute fixed is ill-named
322
323    # Axis labels
324    input_name = "Q"
325    input_unit = "A^{-1}"
326    output_name = "Intensity"
327    output_unit = "cm^{-1}"
328
329    #: default cutoff for polydispersity
330    cutoff = 1e-5
331
332    # Note: Use non-mutable values for class attributes to avoid errors
333    #: parameters that are not fitted
334    non_fittable = ()        # type: Sequence[str]
335
336    #: True if model should appear as a structure factor
337    is_structure_factor = False
338    #: True if model should appear as a form factor
339    is_form_factor = False
340    #: True if model has multiplicity
341    is_multiplicity_model = False
[1f35235]342    #: Multiplicity information
[fa5fd8d]343    multiplicity_info = None # type: MultiplicityInfoType
344
345    # Per-instance variables
346    #: parameter {name: value} mapping
347    params = None      # type: Dict[str, float]
348    #: values for dispersion width, npts, nsigmas and type
349    dispersion = None  # type: Dict[str, Any]
350    #: units and limits for each parameter
[60f03de]351    details = None     # type: Dict[str, Sequence[Any]]
352    #                  # actual type is Dict[str, List[str, float, float]]
[04dc697]353    #: multiplicity value, or None if no multiplicity on the model
[fa5fd8d]354    multiplicity = None     # type: Optional[int]
[04dc697]355    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
356    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]
[fa5fd8d]357
358    def __init__(self, multiplicity=None):
[04dc697]359        # type: (Optional[int]) -> None
[2622b3f]360
[04045f4]361        # TODO: _persistency_dict to persistency_dict throughout sasview
362        # TODO: refactor multiplicity to encompass variants
363        # TODO: dispersion should be a class
[fa5fd8d]364        # TODO: refactor multiplicity info
365        # TODO: separate profile view from multiplicity
366        # The button label, x and y axis labels and scale need to be under
367        # the control of the model, not the fit page.  Maximum flexibility,
368        # the fit page would supply the canvas and the profile could plot
369        # how it wants, but this assumes matplotlib.  Next level is that
370        # we provide some sort of data description including title, labels
371        # and lines to plot.
372
[1f35235]373        # Get the list of hidden parameters given the multiplicity
[04045f4]374        # Don't include multiplicity in the list of parameters
[fa5fd8d]375        self.multiplicity = multiplicity
[04045f4]376        if multiplicity is not None:
377            hidden = self._model_info.get_hidden_parameters(multiplicity)
378            hidden |= set([self.multiplicity_info.control])
379        else:
380            hidden = set()
[8f93522]381        if self._model_info.structure_factor:
382            hidden.add('scale')
383            hidden.add('background')
384            self._model_info.parameters.defaults['background'] = 0.
[04045f4]385
[04dc697]386        self._persistency_dict = {}
[fa5fd8d]387        self.params = collections.OrderedDict()
[b3a85cd]388        self.dispersion = collections.OrderedDict()
[fa5fd8d]389        self.details = {}
[8977226]390        for p in self._model_info.parameters.user_parameters({}, is2d=True):
[04045f4]391            if p.name in hidden:
[fa5fd8d]392                continue
[fcd7bbd]393            self.params[p.name] = p.default
[fa5fd8d]394            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
[fb5914f]395            if p.polydisperse:
[fa5fd8d]396                self.details[p.id+".width"] = [
397                    "", 0.0, 1.0 if p.relative_pd else np.inf
398                ]
[fb5914f]399                self.dispersion[p.name] = {
400                    'width': 0,
401                    'npts': 35,
402                    'nsigmas': 3,
403                    'type': 'gaussian',
404                }
[ce27e21]405
[de97440]406    def __get_state__(self):
[fa5fd8d]407        # type: () -> Dict[str, Any]
[de97440]408        state = self.__dict__.copy()
[4d76711]409        state.pop('_model')
[de97440]410        # May need to reload model info on set state since it has pointers
411        # to python implementations of Iq, etc.
412        #state.pop('_model_info')
413        return state
414
415    def __set_state__(self, state):
[fa5fd8d]416        # type: (Dict[str, Any]) -> None
[de97440]417        self.__dict__ = state
[fb5914f]418        self._model = None
[de97440]419
[ce27e21]420    def __str__(self):
[fa5fd8d]421        # type: () -> str
[ce27e21]422        """
423        :return: string representation
424        """
425        return self.name
426
427    def is_fittable(self, par_name):
[fa5fd8d]428        # type: (str) -> bool
[ce27e21]429        """
430        Check if a given parameter is fittable or not
431
432        :param par_name: the parameter name to check
433        """
[e758662]434        return par_name in self.fixed
[ce27e21]435        #For the future
436        #return self.params[str(par_name)].is_fittable()
437
438
439    def getProfile(self):
[fa5fd8d]440        # type: () -> (np.ndarray, np.ndarray)
[ce27e21]441        """
442        Get SLD profile
443
444        : return: (z, beta) where z is a list of depth of the transition points
445                beta is a list of the corresponding SLD values
446        """
[745b7bb]447        args = {} # type: Dict[str, Any]
[fa5fd8d]448        for p in self._model_info.parameters.kernel_parameters:
449            if p.id == self.multiplicity_info.control:
[745b7bb]450                value = float(self.multiplicity)
[fa5fd8d]451            elif p.length == 1:
[745b7bb]452                value = self.params.get(p.id, np.NaN)
[fa5fd8d]453            else:
[745b7bb]454                value = np.array([self.params.get(p.id+str(k), np.NaN)
[b32dafd]455                                  for k in range(1, p.length+1)])
[745b7bb]456            args[p.id] = value
457
[e7fe459]458        x, y = self._model_info.profile(**args)
459        return x, 1e-6*y
[ce27e21]460
461    def setParam(self, name, value):
[fa5fd8d]462        # type: (str, float) -> None
[ce27e21]463        """
464        Set the value of a model parameter
465
466        :param name: name of the parameter
467        :param value: value of the parameter
468
469        """
470        # Look for dispersion parameters
471        toks = name.split('.')
[de0c4ba]472        if len(toks) == 2:
[ce27e21]473            for item in self.dispersion.keys():
[e758662]474                if item == toks[0]:
[ce27e21]475                    for par in self.dispersion[item]:
[e758662]476                        if par == toks[1]:
[ce27e21]477                            self.dispersion[item][par] = value
478                            return
479        else:
480            # Look for standard parameter
481            for item in self.params.keys():
[e758662]482                if item == name:
[ce27e21]483                    self.params[item] = value
484                    return
485
[63b32bb]486        raise ValueError("Model does not contain parameter %s" % name)
[ce27e21]487
488    def getParam(self, name):
[fa5fd8d]489        # type: (str) -> float
[ce27e21]490        """
491        Set the value of a model parameter
492
493        :param name: name of the parameter
494
495        """
496        # Look for dispersion parameters
497        toks = name.split('.')
[de0c4ba]498        if len(toks) == 2:
[ce27e21]499            for item in self.dispersion.keys():
[e758662]500                if item == toks[0]:
[ce27e21]501                    for par in self.dispersion[item]:
[e758662]502                        if par == toks[1]:
[ce27e21]503                            return self.dispersion[item][par]
504        else:
505            # Look for standard parameter
506            for item in self.params.keys():
[e758662]507                if item == name:
[ce27e21]508                    return self.params[item]
509
[63b32bb]510        raise ValueError("Model does not contain parameter %s" % name)
[ce27e21]511
512    def getParamList(self):
[04dc697]513        # type: () -> Sequence[str]
[ce27e21]514        """
515        Return a list of all available parameters for the model
516        """
[04dc697]517        param_list = list(self.params.keys())
[ce27e21]518        # WARNING: Extending the list with the dispersion parameters
[de0c4ba]519        param_list.extend(self.getDispParamList())
520        return param_list
[ce27e21]521
522    def getDispParamList(self):
[04dc697]523        # type: () -> Sequence[str]
[ce27e21]524        """
[fb5914f]525        Return a list of polydispersity parameters for the model
[ce27e21]526        """
[1780d59]527        # TODO: fix test so that parameter order doesn't matter
[3bcb88c]528        ret = ['%s.%s' % (p_name, ext)
529               for p_name in self.dispersion.keys()
530               for ext in ('npts', 'nsigmas', 'width')]
[9404dd3]531        #print(ret)
[1780d59]532        return ret
[ce27e21]533
534    def clone(self):
[04dc697]535        # type: () -> "SasviewModel"
[ce27e21]536        """ Return a identical copy of self """
537        return deepcopy(self)
538
539    def run(self, x=0.0):
[fa5fd8d]540        # type: (Union[float, (float, float), List[float]]) -> float
[ce27e21]541        """
542        Evaluate the model
543
544        :param x: input q, or [q,phi]
545
546        :return: scattering function P(q)
547
548        **DEPRECATED**: use calculate_Iq instead
549        """
[de0c4ba]550        if isinstance(x, (list, tuple)):
[3c56da87]551            # pylint: disable=unpacking-non-sequence
[ce27e21]552            q, phi = x
[60f03de]553            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
[ce27e21]554        else:
[60f03de]555            return self.calculate_Iq([x])[0]
[ce27e21]556
557
558    def runXY(self, x=0.0):
[fa5fd8d]559        # type: (Union[float, (float, float), List[float]]) -> float
[ce27e21]560        """
561        Evaluate the model in cartesian coordinates
562
563        :param x: input q, or [qx, qy]
564
565        :return: scattering function P(q)
566
567        **DEPRECATED**: use calculate_Iq instead
568        """
[de0c4ba]569        if isinstance(x, (list, tuple)):
[60f03de]570            return self.calculate_Iq([x[0]], [x[1]])[0]
[ce27e21]571        else:
[60f03de]572            return self.calculate_Iq([x])[0]
[ce27e21]573
574    def evalDistribution(self, qdist):
[04dc697]575        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
[d138d43]576        r"""
[ce27e21]577        Evaluate a distribution of q-values.
578
[d138d43]579        :param qdist: array of q or a list of arrays [qx,qy]
[ce27e21]580
[d138d43]581        * For 1D, a numpy array is expected as input
[ce27e21]582
[d138d43]583        ::
[ce27e21]584
[d138d43]585            evalDistribution(q)
[ce27e21]586
[d138d43]587          where *q* is a numpy array.
[ce27e21]588
[d138d43]589        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input
[ce27e21]590
[d138d43]591        ::
[ce27e21]592
[d138d43]593              qx = [ qx[0], qx[1], qx[2], ....]
594              qy = [ qy[0], qy[1], qy[2], ....]
[ce27e21]595
[d138d43]596        If the model is 1D only, then
[ce27e21]597
[d138d43]598        .. math::
[ce27e21]599
[d138d43]600            q = \sqrt{q_x^2+q_y^2}
[ce27e21]601
602        """
[de0c4ba]603        if isinstance(qdist, (list, tuple)):
[ce27e21]604            # Check whether we have a list of ndarrays [qx,qy]
605            qx, qy = qdist
[05df1de]606            return self.calculate_Iq(qx, qy)
[ce27e21]607
608        elif isinstance(qdist, np.ndarray):
609            # We have a simple 1D distribution of q-values
610            return self.calculate_Iq(qdist)
611
612        else:
[3c56da87]613            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
614                            % type(qdist))
[ce27e21]615
[9dcb21d]616    def calc_composition_models(self, qx):
[64614ad]617        """
[9dcb21d]618        returns parts of the composition model or None if not a composition
619        model.
[64614ad]620        """
[946c8d27]621        # TODO: have calculate_Iq return the intermediates.
622        #
623        # The current interface causes calculate_Iq() to be called twice,
624        # once to get the combined result and again to get the intermediate
625        # results.  This is necessary for now.
626        # Long term, the solution is to change the interface to calculate_Iq
627        # so that it returns a results object containing all the bits:
[9644b5a]628        #     the A, B, C, ... of the composition model (and any subcomponents?)
[946c8d27]629        #     the P and S of the product model,
630        #     the combined model before resolution smearing,
631        #     the sasmodel before sesans conversion,
632        #     the oriented 2D model used to fit oriented usans data,
633        #     the final I(q),
634        #     ...
[9644b5a]635        #
[946c8d27]636        # Have the model calculator add all of these blindly to the data
637        # tree, and update the graphs which contain them.  The fitter
638        # needs to be updated to use the I(q) value only, ignoring the rest.
639        #
640        # The simple fix of returning the existing intermediate results
641        # will not work for a couple of reasons: (1) another thread may
642        # sneak in to compute its own results before calc_composition_models
643        # is called, and (2) calculate_Iq is currently called three times:
644        # once with q, once with q values before qmin and once with q values
645        # after q max.  Both of these should be addressed before
646        # replacing this code.
[9644b5a]647        composition = self._model_info.composition
648        if composition and composition[0] == 'product': # only P*S for now
649            with calculation_lock:
650                self._calculate_Iq(qx)
651                return self._intermediate_results
652        else:
653            return None
[bf8c271]654
[fa5fd8d]655    def calculate_Iq(self, qx, qy=None):
656        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
[ff7119b]657        """
658        Calculate Iq for one set of q with the current parameters.
659
660        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.
661
662        This should NOT be used for fitting since it copies the *q* vectors
663        to the card for each evaluation.
664        """
[a38b065]665        ## uncomment the following when trying to debug the uncoordinated calls
666        ## to calculate_Iq
667        #if calculation_lock.locked():
[724257c]668        #    logger.info("calculation waiting for another thread to complete")
669        #    logger.info("\n".join(traceback.format_stack()))
[a38b065]670
671        with calculation_lock:
672            return self._calculate_Iq(qx, qy)
673
674    def _calculate_Iq(self, qx, qy=None):
[fb5914f]675        if self._model is None:
[d2bb604]676            self._model = core.build_model(self._model_info)
[fa5fd8d]677        if qy is not None:
678            q_vectors = [np.asarray(qx), np.asarray(qy)]
679        else:
680            q_vectors = [np.asarray(qx)]
[a738209]681        calculator = self._model.make_kernel(q_vectors)
[6a0d6aa]682        parameters = self._model_info.parameters
683        pairs = [self._get_weights(p) for p in parameters.call_parameters]
[9c1a59c]684        #weights.plot_weights(self._model_info, pairs)
[bde38b5]685        call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
[4edec6f]686        #call_details.show()
[05df1de]687        #print("================ parameters ==================")
688        #for p, v in zip(parameters.call_parameters, pairs): print(p.name, v[0])
[ce99754]689        #for k, p in enumerate(self._model_info.parameters.call_parameters):
690        #    print(k, p.name, *pairs[k])
[4edec6f]691        #print("params", self.params)
692        #print("values", values)
693        #print("is_mag", is_magnetic)
[6a0d6aa]694        result = calculator(call_details, values, cutoff=self.cutoff,
[9eb3632]695                            magnetic=is_magnetic)
[ce99754]696        #print("result", result)
[bf8c271]697        self._intermediate_results = getattr(calculator, 'results', None)
[a738209]698        calculator.release()
[d533590]699        #self._model.release()
[ce27e21]700        return result
701
702    def calculate_ER(self):
[fa5fd8d]703        # type: () -> float
[ce27e21]704        """
705        Calculate the effective radius for P(q)*S(q)
706
707        :return: the value of the effective radius
708        """
[4bfd277]709        if self._model_info.ER is None:
[ce27e21]710            return 1.0
711        else:
[4bfd277]712            value, weight = self._dispersion_mesh()
713            fv = self._model_info.ER(*value)
[9404dd3]714            #print(values[0].shape, weights.shape, fv.shape)
[4bfd277]715            return np.sum(weight * fv) / np.sum(weight)
[ce27e21]716
717    def calculate_VR(self):
[fa5fd8d]718        # type: () -> float
[ce27e21]719        """
720        Calculate the volf ratio for P(q)*S(q)
721
722        :return: the value of the volf ratio
723        """
[4bfd277]724        if self._model_info.VR is None:
[ce27e21]725            return 1.0
726        else:
[4bfd277]727            value, weight = self._dispersion_mesh()
728            whole, part = self._model_info.VR(*value)
729            return np.sum(weight * part) / np.sum(weight * whole)
[ce27e21]730
731    def set_dispersion(self, parameter, dispersion):
[7c3fb15]732        # type: (str, weights.Dispersion) -> None
[ce27e21]733        """
734        Set the dispersion object for a model parameter
735
736        :param parameter: name of the parameter [string]
737        :param dispersion: dispersion object of type Dispersion
738        """
[fa800e72]739        if parameter in self.params:
[1780d59]740            # TODO: Store the disperser object directly in the model.
[56b2687]741            # The current method of relying on the sasview GUI to
[fa800e72]742            # remember them is kind of funky.
[1780d59]743            # Note: can't seem to get disperser parameters from sasview
[9c1a59c]744            # (1) Could create a sasview model that has not yet been
[1780d59]745            # converted, assign the disperser to one of its polydisperse
746            # parameters, then retrieve the disperser parameters from the
[9c1a59c]747            # sasview model.
748            # (2) Could write a disperser parameter retriever in sasview.
749            # (3) Could modify sasview to use sasmodels.weights dispersers.
[1780d59]750            # For now, rely on the fact that the sasview only ever uses
751            # new dispersers in the set_dispersion call and create a new
752            # one instead of trying to assign parameters.
[ce27e21]753            self.dispersion[parameter] = dispersion.get_pars()
754        else:
[7c3fb15]755            raise ValueError("%r is not a dispersity or orientation parameter"
756                             % parameter)
[ce27e21]757
[aa4946b]758    def _dispersion_mesh(self):
[fa5fd8d]759        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
[ce27e21]760        """
761        Create a mesh grid of dispersion parameters and weights.
762
763        Returns [p1,p2,...],w where pj is a vector of values for parameter j
764        and w is a vector containing the products for weights for each
765        parameter set in the vector.
766        """
[4bfd277]767        pars = [self._get_weights(p)
768                for p in self._model_info.parameters.call_parameters
769                if p.type == 'volume']
[9eb3632]770        return dispersion_mesh(self._model_info, pars)
[ce27e21]771
772    def _get_weights(self, par):
[fa5fd8d]773        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
[de0c4ba]774        """
[fb5914f]775        Return dispersion weights for parameter
[de0c4ba]776        """
[fa5fd8d]777        if par.name not in self.params:
778            if par.name == self.multiplicity_info.control:
[32f87a5]779                return self.multiplicity, [self.multiplicity], [1.0]
[fa5fd8d]780            else:
[17db833]781                # For hidden parameters use default values.  This sets
782                # scale=1 and background=0 for structure factors
783                default = self._model_info.parameters.defaults.get(par.name, np.NaN)
784                return default, [default], [1.0]
[fa5fd8d]785        elif par.polydisperse:
[32f87a5]786            value = self.params[par.name]
[fb5914f]787            dis = self.dispersion[par.name]
[9c1a59c]788            if dis['type'] == 'array':
[32f87a5]789                dispersity, weight = dis['values'], dis['weights']
[9c1a59c]790            else:
[32f87a5]791                dispersity, weight = weights.get_weights(
[9c1a59c]792                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
[32f87a5]793                    value, par.limits, par.relative_pd)
794            return value, dispersity, weight
[fb5914f]795        else:
[32f87a5]796            value = self.params[par.name]
[ce99754]797            return value, [value], [1.0]
[ce27e21]798
def test_cylinder():
    # type: () -> float
    """
    Run the standard cylinder model at the 2-D point (qx, qy) = (0.1, 0.1)
    and return the computed value.
    """
    model_class = _make_standard_model('cylinder')
    model = model_class()
    q_point = [0.1, 0.1]
    return model.evalDistribution(q_point)
[de97440]807
def test_structure_factor():
    # type: () -> None
    """
    Test that the hardsphere model runs in 1-D and 2-D without producing NaN.

    The 1-D point q = 0.1*sqrt(2) has the same magnitude as the 2-D point
    (qx, qy) = (0.1, 0.1), so both evaluations probe the same |q|.

    :raises ValueError: if either evaluation contains NaN
    """
    Model = _make_standard_model('hardsphere')
    model = Model()
    value2d = model.evalDistribution([0.1, 0.1])
    value1d = model.evalDistribution(np.array([0.1*np.sqrt(2)]))
    # np.any() handles results of any shape, not only scalars or
    # single-element arrays.
    if np.any(np.isnan(value1d)) or np.any(np.isnan(value2d)):
        raise ValueError("hardsphere returns nan")
[8f93522]820
def test_product():
    # type: () -> None
    """
    Test that the 2-D cylinder*hayter_msa product model runs and doesn't
    produce NaN.

    :raises ValueError: if the product model returns NaN
    """
    S = _make_standard_model('hayter_msa')()
    P = _make_standard_model('cylinder')()
    model = MultiplicationModel(P, S)
    value = model.evalDistribution([0.1, 0.1])
    if np.any(np.isnan(value)):
        raise ValueError("cylinder*hayter_msa returns NaN")
832
def test_rpa():
    # type: () -> float
    """
    Instantiate the RPA model class with argument 3, evaluate it at the
    2-D point (qx, qy) = (0.1, 0.1) and return the result.
    """
    rpa_class = _make_standard_model('rpa')
    return rpa_class(3).evalDistribution([0.1, 0.1])
[04045f4]841
def test_empty_distribution():
    # type: () -> None
    """
    Make sure that sasmodels returns zero (here equal to the background,
    which is set to 0) when there are no polydispersity points.

    The radius is set to -1 so that, presumably, no valid polydispersity
    points remain.
    """
    Cylinder = _make_standard_model('cylinder')
    cylinder = Cylinder()
    cylinder.setParam('radius', -1.0)
    cylinder.setParam('background', 0.)
    Iq = cylinder.evalDistribution(np.asarray([0.1]))
    assert Iq[0] == 0., "empty distribution fails"
[4d76711]853
def test_model_list():
    # type: () -> None
    """
    Make sure that all models build as sasview models

    Any failure is re-raised after annotate_exception has been called with
    the name of the model being loaded.
    """
    from .exception import annotate_exception
    for name in core.list_models():
        try:
            _make_standard_model(name)
        except Exception:
            # Catch Exception rather than everything (bare except) so that
            # KeyboardInterrupt/SystemExit propagate untouched.
            annotate_exception("when loading "+name)
            raise
866
def test_old_name():
    # type: () -> None
    """
    Load and run cylinder model as sas-models-CylinderModel

    Does nothing when old-style plugin support is disabled or when the
    sas package (sasview) cannot be imported.
    """
    if SUPPORT_OLD_STYLE_PLUGINS:
        try:
            import sas  # only checking that sasview is on the path
        except ImportError:
            pass
        else:
            load_standard_models()
            from sas.models.CylinderModel import CylinderModel
            model = CylinderModel()
            model.evalDistribution([0.1, 0.1])
882
def magnetic_demo():
    # type: () -> None
    """
    Display the 2-D pattern of the sphere model with 'M0:sld' set to 8.

    The model is evaluated on a 500 x 500 (qx, qy) grid spanning
    [-0.35, 0.35] in each direction, and log(I + 0.001) is shown as an image.
    """
    # pyplot is matplotlib's supported plotting interface; the pylab
    # module is discouraged.
    import matplotlib.pyplot as plt

    Model = _make_standard_model('sphere')
    model = Model()
    model.setParam('M0:sld', 8)
    q = np.linspace(-0.35, 0.35, 500)
    qx, qy = np.meshgrid(q, q)
    result = model.calculate_Iq(qx.flatten(), qy.flatten())
    result = result.reshape(qx.shape)

    # The small offset keeps log() finite where the intensity is zero.
    plt.imshow(np.log(result + 0.001))
    plt.show()
895
if __name__ == "__main__":
    # Smoke check: evaluate the cylinder model at a single 2-D point.
    # Uncomment any of the lines below to run the other demos/tests.
    cylinder_value = test_cylinder()
    print("cylinder(0.1,0.1)=%g" % cylinder_value)
    #magnetic_demo()
    #test_product()
    #test_structure_factor()
    #print("rpa:", test_rpa())
    #test_empty_distribution()
Note: See TracBrowser for help on using the repository browser.