source: sasmodels/sasmodels/sasview_model.py @ a80e64c

core_shell_microgels, costrafo411, magnetic_model, ticket-1257-vesicle-product, ticket_1156, ticket_1265_superball, ticket_822_more_unit_tests
Last change on this file since a80e64c was a80e64c, checked in by Paul Kienzle <pkienzle@…>, 8 years ago

Add sasview-compatible MultiplicationModel interface to product model

  • Property mode set to 100644
File size: 26.4 KB
RevLine 
[87985ca]1"""
2Sasview model constructor.
3
4Given a module defining an OpenCL kernel such as sasmodels.models.cylinder,
5create a sasview model class to run that kernel as follows::
6
[92d38285]7    from sasmodels.sasview_model import load_custom_model
8    CylinderModel = load_custom_model('sasmodels/models/cylinder.py')
[87985ca]9"""
[4d76711]10from __future__ import print_function
[87985ca]11
[ce27e21]12import math
13from copy import deepcopy
[2622b3f]14import collections
[4d76711]15import traceback
16import logging
[9457498]17from os.path import basename, splitext
[ce27e21]18
[7ae2b7f]19import numpy as np  # type: ignore
[ce27e21]20
[aa4946b]21from . import core
[4d76711]22from . import custom
[a80e64c]23from . import product
[72a081d]24from . import generate
[fb5914f]25from . import weights
[6d6508e]26from . import modelinfo
[bde38b5]27from .details import make_kernel_args, dispersion_mesh
[ff7119b]28
# Names used only by the "# type:" comments in this module.  They are
# optional: if any of these imports fails the module still works, it just
# has no static typing names available.
try:
    from typing import Dict, Mapping, Any, Sequence, Tuple, NamedTuple, List, Optional, Union, Callable
    from .modelinfo import ModelInfo, Parameter
    from .kernel import KernelModel
    # Typed counterpart of the MultiplicityInfo namedtuple defined below;
    # the type name is spelled the same as that tuple.
    MultiplicityInfoType = NamedTuple(
        'MultiplicityInfo',
        [("number", int), ("control", str), ("choices", List[str]),
         ("x_axis_label", str)])
    SasviewModelType = Callable[[int], "SasviewModel"]
except ImportError:
    pass
40
# When True, load_standard_models() also calls _register_old_models() so
# that the models are reachable under their old sas.models names.
SUPPORT_OLD_STYLE_PLUGINS = True
42
def _register_old_models():
    # type: () -> None
    """
    Expose the new models to sasview under their old names.

    *sas.sascalc.fit* is installed as *sas.models* (both as an attribute of
    *sas* and in *sys.modules*) so that plugin modules can import
    *sas.models.pluginmodel*.  Then, for every entry of the conversion
    table, a stand-in class named after the old model and holding the new
    model class as its only attribute is registered in *sys.modules* as
    ``sas.models.<old name>``.
    """
    import sys
    import sas
    import sas.sascalc.fit

    # Alias the fit package as sas.models before anything imports from it.
    sys.modules['sas.models'] = sas.sascalc.fit
    sas.models = sas.sascalc.fit

    import sas.models
    from sasmodels.conversion_table import CONVERSION_TABLE

    for new_name, conversion in CONVERSION_TABLE.items():
        # First entry of the conversion record is the old model name.
        old_name = conversion[0]
        old_path = 'sas.models.' + old_name
        # Fake "module": a class whose single attribute is the model class.
        ConstructedModule = type(old_name, (), {old_name: find_model(new_name)})
        # NOTE(review): the attribute is stored under the dotted path, not
        # under old_name -- confirm that this is what callers expect.
        setattr(sas.models, old_path, ConstructedModule)
        sys.modules[old_path] = ConstructedModule
66
[9457498]67
# TODO: separate x_axis_label from multiplicity info
# Record describing the multiplicity of a model: the number of variants,
# the name of the controlling parameter, the list of choices and the
# x-axis label.
MultiplicityInfo = collections.namedtuple(
    'MultiplicityInfo', "number control choices x_axis_label")
73
# Model classes loaded so far, keyed by model name.
MODELS = {}
def find_model(modelname):
    # type: (str) -> SasviewModelType
    """
    Return the model class called *modelname*.

    Names ending in ``.py`` are loaded as custom models; any other name
    must already be present in *MODELS*.

    Raises *ValueError* if the name is not a known model.
    """
    # TODO: used by sum/product model to load an existing model
    # TODO: doesn't handle custom models properly
    if modelname.endswith('.py'):
        return load_custom_model(modelname)
    if modelname not in MODELS:
        raise ValueError("unknown model %r"%modelname)
    return MODELS[modelname]
89
[56b2687]90
[fa5fd8d]91# TODO: figure out how to say that the return type is a subclass
def load_standard_models():
    # type: () -> List[SasviewModelType]
    """
    Build every predefined model and return the resulting classes.

    Each class is also stored in *MODELS* under its name.  A model that
    fails to load is left out; its traceback is written to the log.  If
    *SUPPORT_OLD_STYLE_PLUGINS* is set, the old model names are registered
    afterwards.
    """
    loaded = []
    for model_name in core.list_models():
        try:
            model_class = _make_standard_model(model_name)
            MODELS[model_name] = model_class
            loaded.append(model_class)
        except Exception:
            logging.error(traceback.format_exc())

    if SUPPORT_OLD_STYLE_PLUGINS:
        _register_old_models()

    return loaded
[de97440]111
[4d76711]112
def load_custom_model(path):
    # type: (str) -> SasviewModelType
    """
    Load the custom model stored at *path* and register it in *MODELS*.

    If the loaded module provides a *Model* class it is used directly,
    after filling in any missing *name*, *filename* and *id* attributes.
    Otherwise (any *AttributeError* along the way) the class is built from
    the model info of the module.

    Returns the model class.
    """
    kernel_module = custom.load_custom_kernel_module(path)
    try:
        model = kernel_module.Model
        # Old style models do not set the name in the class attributes, so
        # set it here; this name will be overridden when the object is created
        # with an instance variable that has the same value.
        if model.name == "":
            model.name = splitext(basename(path))[0]
        if not hasattr(model, 'filename'):
            module_file = kernel_module.__file__
            model.filename = module_file
            # For old models, treat .pyc and .py files interchangeably.
            # This is needed because of the Sum|Multi(p1,p2) types of models
            # and the convoluted way in which they are created.
            if module_file.endswith(".py"):
                logging.info("Loading %s as .pyc", module_file)
                model.filename = module_file+'c'
        if not hasattr(model, 'id'):
            model.id = splitext(basename(model.filename))[0]
    except AttributeError:
        model_info = modelinfo.make_model_info(kernel_module)
        model = _make_model_from_info(model_info)

    def _name_taken(candidate):
        # True when *candidate* is registered for a model from another file.
        return (candidate in MODELS
                and model.filename != MODELS[candidate].filename)

    # If a model name already exists and we are loading a different model,
    # use the model file name as the model name.
    if _name_taken(model.name):
        _previous_name = model.name
        model.name = model.id

        # If the new model name is still in the model list (for instance,
        # if we put a cylinder.py in our plug-in directory), then append
        # an identifier.
        if _name_taken(model.name):
            model.name = model.id + '_user'
        logging.info("Model %s already exists: using %s [%s]", _previous_name, model.name, model.filename)

    MODELS[model.name] = model
    return model
[4d76711]155
[87985ca]156
def _make_standard_model(name):
    # type: (str) -> SasviewModelType
    """
    Build the sasview model class for the model called *name*.

    The kernel module is loaded with *generate.load_kernel_module*, turned
    into model info, and wrapped by *_make_model_from_info*.

    Returns a class that can be used directly as a sasview model.
    """
    model_info = modelinfo.make_model_info(generate.load_kernel_module(name))
    return _make_model_from_info(model_info)
[72a081d]169
170
def MultiplicationModel(form_factor, structure_factor):
    # type: ("SasviewModel", "SasviewModel") -> "SasviewModel"
    """
    Return a model instance for the product of two sasview models.

    The model info of *form_factor* and *structure_factor* is combined
    with *product.make_product_info*; a new model class is built from the
    result and instantiated with its default arguments.
    """
    product_info = product.make_product_info(
        form_factor._model_info, structure_factor._model_info)
    return _make_model_from_info(product_info)()
177
def _make_model_from_info(model_info):
    # type: (ModelInfo) -> SasviewModelType
    """
    Create a SasviewModel subclass named after *model_info*.

    The class attributes come from *_generate_model_attributes*, plus the
    model file name and a constructor that forwards *multiplicity* to
    *SasviewModel.__init__*.
    """
    def __init__(self, multiplicity=None):
        SasviewModel.__init__(self, multiplicity=multiplicity)

    class_attrs = _generate_model_attributes(model_info)
    class_attrs.update(__init__=__init__, filename=model_info.filename)
    return type(model_info.name, (SasviewModel,), class_attrs)  # type: SasviewModelType
190
def _generate_model_attributes(model_info):
    # type: (ModelInfo) -> Dict[str, Any]
    """
    Build the dictionary of class attributes for a model.

    The attributes carry everything needed to query the model details, so
    that a model does not have to be instantiated just to inspect it.

    All the attributes should be immutable to avoid accidents.
    """

    # TODO: allow model to override axis labels input/output name/unit

    kernel_pars = model_info.parameters.kernel_parameters
    xlabel = "" if model_info.profile is None else model_info.profile_axes[0]

    # Multiplicity: described by the control parameter, when there is one.
    non_fittable = []  # type: List[str]
    variants = MultiplicityInfo(0, "", [], xlabel)
    control = next(
        (par for par in kernel_pars if par.name == model_info.control), None)
    if control is not None:
        non_fittable.append(control.name)
        # Number of variants: the number of choices if the parameter has
        # choices, otherwise its upper limit.
        if control.choices:
            count = len(control.choices)
        else:
            count = int(control.limits[1])
        variants = MultiplicityInfo(count, control.name, control.choices, xlabel)

    # Only a single drop-down list parameter available: use the first
    # parameter that defines choices.
    fun_list = []
    chooser = next((par for par in kernel_pars if par.choices), None)
    if chooser is not None:
        fun_list = chooser.choices
        if chooser.length > 1:
            non_fittable.extend(
                chooser.id+str(k) for k in range(1, chooser.length+1))

    # Organize parameter sets
    orientation_params = []
    magnetic_params = []
    fixed = []
    for par in model_info.parameters.user_parameters():
        width_name = par.name+".width"
        if par.type == 'orientation':
            orientation_params.extend([par.name, width_name])
            fixed.append(width_name)
        elif par.type == 'magnetic':
            orientation_params.append(par.name)
            magnetic_params.append(par.name)
            fixed.append(width_name)

    # Build class dictionary; sequences are frozen into tuples.
    attrs = {
        '_model_info': model_info,
        'name': model_info.name,
        'id': model_info.id,
        'description': model_info.description,
        'category': model_info.category,
        'is_structure_factor': model_info.structure_factor,
        'is_form_factor': model_info.ER is not None,
        'is_multiplicity_model': variants.number > 1,
        'multiplicity_info': variants,
        'orientation_params': tuple(orientation_params),
        'magnetic_params': tuple(magnetic_params),
        'fixed': tuple(fixed),
        'non_fittable': tuple(non_fittable),
        'fun_list': tuple(fun_list),
    }  # type: Dict[str, Any]

    return attrs
[4d76711]259
class SasviewModel(object):
    """
    Sasview wrapper for opencl/ctypes model.

    Instances hold the current parameter values (*params*), the
    polydispersity settings (*dispersion*) and the units/limits of each
    parameter (*details*).  The compute kernel is built on the first call
    to :meth:`calculate_Iq`.
    """
    # Model parameters for the specific model are set in the class constructor
    # via the _generate_model_attributes function, which subclasses
    # SasviewModel.  They are included here for typing and documentation
    # purposes.
    _model = None       # type: KernelModel
    _model_info = None  # type: ModelInfo
    #: load/save name for the model
    id = None           # type: str
    #: display name for the model
    name = None         # type: str
    #: short model description
    description = None  # type: str
    #: default model category
    category = None     # type: str

    #: names of the orientation parameters in the order they appear
    orientation_params = None # type: Sequence[str]
    #: names of the magnetic parameters in the order they appear
    magnetic_params = None    # type: Sequence[str]
    #: names of the fittable parameters
    fixed = None              # type: Sequence[str]
    # TODO: the attribute fixed is ill-named

    # Axis labels
    input_name = "Q"
    input_unit = "A^{-1}"
    output_name = "Intensity"
    output_unit = "cm^{-1}"

    #: default cutoff for polydispersity
    cutoff = 1e-5

    # Note: Use non-mutable values for class attributes to avoid errors
    #: parameters that are not fitted
    non_fittable = ()        # type: Sequence[str]

    #: True if model should appear as a structure factor
    is_structure_factor = False
    #: True if model should appear as a form factor
    is_form_factor = False
    #: True if model has multiplicity
    is_multiplicity_model = False
    #: Multiplicity information
    multiplicity_info = None # type: MultiplicityInfoType

    # Per-instance variables
    #: parameter {name: value} mapping
    params = None      # type: Dict[str, float]
    #: values for dispersion width, npts, nsigmas and type
    dispersion = None  # type: Dict[str, Any]
    #: units and limits for each parameter
    details = None     # type: Dict[str, Sequence[Any]]
    #                  # actual type is Dict[str, List[str, float, float]]
    #: multiplicity value, or None if no multiplicity on the model
    multiplicity = None     # type: Optional[int]
    #: memory for polydispersity array if using ArrayDispersion (used by sasview).
    _persistency_dict = None # type: Dict[str, Tuple[np.ndarray, np.ndarray]]

    def __init__(self, multiplicity=None):
        # type: (Optional[int]) -> None
        """
        Set up the default parameter values for the model.

        :param multiplicity: multiplicity value, or None if the model has
            no multiplicity
        """

        # TODO: _persistency_dict to persistency_dict throughout sasview
        # TODO: refactor multiplicity to encompass variants
        # TODO: dispersion should be a class
        # TODO: refactor multiplicity info
        # TODO: separate profile view from multiplicity
        # The button label, x and y axis labels and scale need to be under
        # the control of the model, not the fit page.  Maximum flexibility,
        # the fit page would supply the canvas and the profile could plot
        # how it wants, but this assumes matplotlib.  Next level is that
        # we provide some sort of data description including title, labels
        # and lines to plot.

        # Get the list of hidden parameters given the multiplicity
        # Don't include multiplicity in the list of parameters
        self.multiplicity = multiplicity
        if multiplicity is not None:
            hidden = self._model_info.get_hidden_parameters(multiplicity)
            hidden |= set([self.multiplicity_info.control])
        else:
            hidden = set()
        if self._model_info.structure_factor:
            # Structure factors do not show scale and background; the
            # background default is forced to zero on the model info.
            hidden.add('scale')
            hidden.add('background')
            self._model_info.parameters.defaults['background'] = 0.

        self._persistency_dict = {}
        self.params = collections.OrderedDict()
        self.dispersion = collections.OrderedDict()
        self.details = {}
        for p in self._model_info.parameters.user_parameters():
            if p.name in hidden:
                continue
            self.params[p.name] = p.default
            self.details[p.id] = [p.units, p.limits[0], p.limits[1]]
            if p.polydisperse:
                self.details[p.id+".width"] = [
                    "", 0.0, 1.0 if p.relative_pd else np.inf
                ]
                self.dispersion[p.name] = {
                    'width': 0,
                    'npts': 35,
                    'nsigmas': 3,
                    'type': 'gaussian',
                }

    def __get_state__(self):
        # type: () -> Dict[str, Any]
        """
        Return a copy of the instance dictionary without the built kernel.
        """
        # NOTE(review): the pickle/copy protocol looks for __getstate__ and
        # __setstate__ (no inner underscore), so these two methods only run
        # when called explicitly -- confirm whether that is intended.
        state = self.__dict__.copy()
        # _model is only a class attribute until calculate_Iq builds the
        # kernel, so it may be missing from the instance dictionary.
        state.pop('_model', None)
        # May need to reload model info on set state since it has pointers
        # to python implementations of Iq, etc.
        #state.pop('_model_info')
        return state

    def __set_state__(self, state):
        # type: (Dict[str, Any]) -> None
        """
        Restore the instance dictionary from *state* and forget the kernel.
        """
        self.__dict__ = state
        self._model = None

    def __str__(self):
        # type: () -> str
        """
        :return: string representation
        """
        return self.name

    def is_fittable(self, par_name):
        # type: (str) -> bool
        """
        Check if a given parameter is fittable or not

        The check is membership of *par_name* in the (ill-named) *fixed*
        attribute.

        :param par_name: the parameter name to check
        """
        return par_name in self.fixed
        #For the future
        #return self.params[str(par_name)].is_fittable()


    def getProfile(self):
        # type: () -> Tuple[np.ndarray, np.ndarray]
        """
        Get SLD profile

        : return: (z, beta) where z is a list of depth of the transition points
                beta is a list of the corresponding SLD values
        """
        args = {} # type: Dict[str, Any]
        for p in self._model_info.parameters.kernel_parameters:
            if p.id == self.multiplicity_info.control:
                value = float(self.multiplicity)
            elif p.length == 1:
                value = self.params.get(p.id, np.nan)
            else:
                # Vector parameter: gather p.id1, p.id2, ... into an array.
                value = np.array([self.params.get(p.id+str(k), np.nan)
                                  for k in range(1, p.length+1)])
            args[p.id] = value

        x, y = self._model_info.profile(**args)
        return x, 1e-6*y

    def setParam(self, name, value):
        # type: (str, float) -> None
        """
        Set the value of a model parameter

        A name of the form ``par.field`` sets a dispersion field of *par*;
        any other name sets a standard parameter.

        :param name: name of the parameter
        :param value: value of the parameter

        Raises *ValueError* if the model has no such parameter.
        """
        # Look for dispersion parameters
        toks = name.split('.')
        if len(toks) == 2:
            par_name, field = toks
            if par_name in self.dispersion and field in self.dispersion[par_name]:
                self.dispersion[par_name][field] = value
                return
        elif name in self.params:
            # Standard parameter
            self.params[name] = value
            return

        raise ValueError("Model does not contain parameter %s" % name)

    def getParam(self, name):
        # type: (str) -> float
        """
        Get the value of a model parameter

        A name of the form ``par.field`` returns a dispersion field of
        *par*; any other name returns a standard parameter.

        :param name: name of the parameter

        Raises *ValueError* if the model has no such parameter.
        """
        # Look for dispersion parameters
        toks = name.split('.')
        if len(toks) == 2:
            par_name, field = toks
            if par_name in self.dispersion and field in self.dispersion[par_name]:
                return self.dispersion[par_name][field]
        elif name in self.params:
            # Standard parameter
            return self.params[name]

        raise ValueError("Model does not contain parameter %s" % name)

    def getParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of all available parameters for the model
        """
        param_list = list(self.params.keys())
        # WARNING: Extending the list with the dispersion parameters
        param_list.extend(self.getDispParamList())
        return param_list

    def getDispParamList(self):
        # type: () -> Sequence[str]
        """
        Return a list of polydispersity parameters for the model
        """
        # TODO: fix test so that parameter order doesn't matter
        ret = ['%s.%s' % (p_name, ext)
               for p_name in self.dispersion.keys()
               for ext in ('npts', 'nsigmas', 'width')]
        #print(ret)
        return ret

    def clone(self):
        # type: () -> "SasviewModel"
        """ Return a identical copy of self """
        return deepcopy(self)

    def run(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model

        :param x: input q, or [q,phi]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            # pylint: disable=unpacking-non-sequence
            q, phi = x
            return self.calculate_Iq([q*math.cos(phi)], [q*math.sin(phi)])[0]
        else:
            return self.calculate_Iq([x])[0]


    def runXY(self, x=0.0):
        # type: (Union[float, (float, float), List[float]]) -> float
        """
        Evaluate the model in cartesian coordinates

        :param x: input q, or [qx, qy]

        :return: scattering function P(q)

        **DEPRECATED**: use calculate_Iq instead
        """
        if isinstance(x, (list, tuple)):
            return self.calculate_Iq([x[0]], [x[1]])[0]
        else:
            return self.calculate_Iq([x])[0]

    def evalDistribution(self, qdist):
        # type: (Union[np.ndarray, Tuple[np.ndarray, np.ndarray], List[np.ndarray]]) -> np.ndarray
        r"""
        Evaluate a distribution of q-values.

        :param qdist: array of q or a list of arrays [qx,qy]

        * For 1D, a numpy array is expected as input

        ::

            evalDistribution(q)

          where *q* is a numpy array.

        * For 2D, a list of *[qx,qy]* is expected with 1D arrays as input

        ::

              qx = [ qx[0], qx[1], qx[2], ....]
              qy = [ qy[0], qy[1], qy[2], ....]

        If the model is 1D only, then

        .. math::

            q = \sqrt{q_x^2+q_y^2}

        """
        if isinstance(qdist, (list, tuple)):
            # Check whether we have a list of ndarrays [qx,qy]
            qx, qy = qdist
            if not self._model_info.parameters.has_2d:
                return self.calculate_Iq(np.sqrt(qx ** 2 + qy ** 2))
            else:
                return self.calculate_Iq(qx, qy)

        elif isinstance(qdist, np.ndarray):
            # We have a simple 1D distribution of q-values
            return self.calculate_Iq(qdist)

        else:
            raise TypeError("evalDistribution expects q or [qx, qy], not %r"
                            % type(qdist))

    def calculate_Iq(self, qx, qy=None):
        # type: (Sequence[float], Optional[Sequence[float]]) -> np.ndarray
        """
        Calculate Iq for one set of q with the current parameters.

        If the model is 1D, use *q*.  If 2D, use *qx*, *qy*.

        This should NOT be used for fitting since it copies the *q* vectors
        to the card for each evaluation.
        """
        #core.HAVE_OPENCL = False
        if self._model is None:
            self._model = core.build_model(self._model_info)
        if qy is not None:
            q_vectors = [np.asarray(qx), np.asarray(qy)]
        else:
            q_vectors = [np.asarray(qx)]
        # Release the kernel and the model even if the evaluation fails.
        try:
            calculator = self._model.make_kernel(q_vectors)
            try:
                parameters = self._model_info.parameters
                pairs = [self._get_weights(p) for p in parameters.call_parameters]
                #weights.plot_weights(self._model_info, pairs)
                call_details, values, is_magnetic = make_kernel_args(calculator, pairs)
                #call_details.show()
                #print("pairs", pairs)
                #print("params", self.params)
                #print("values", values)
                #print("is_mag", is_magnetic)
                result = calculator(call_details, values, cutoff=self.cutoff,
                                    magnetic=is_magnetic)
            finally:
                calculator.release()
        finally:
            self._model.release()
        return result

    def calculate_ER(self):
        # type: () -> float
        """
        Calculate the effective radius for P(q)*S(q)

        :return: the value of the effective radius
        """
        if self._model_info.ER is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            fv = self._model_info.ER(*value)
            #print(values[0].shape, weights.shape, fv.shape)
            return np.sum(weight * fv) / np.sum(weight)

    def calculate_VR(self):
        # type: () -> float
        """
        Calculate the volf ratio for P(q)*S(q)

        :return: the value of the volf ratio
        """
        if self._model_info.VR is None:
            return 1.0
        else:
            value, weight = self._dispersion_mesh()
            whole, part = self._model_info.VR(*value)
            return np.sum(weight * part) / np.sum(weight * whole)

    def set_dispersion(self, parameter, dispersion):
        # type: (str, weights.Dispersion) -> None
        """
        Set the dispersion object for a model parameter

        :param parameter: name of the parameter [string]
        :param dispersion: dispersion object of type Dispersion

        Raises *ValueError* if *parameter* is not a parameter of the model.
        """
        if parameter in self.params:
            # TODO: Store the disperser object directly in the model.
            # The current method of relying on the sasview GUI to
            # remember them is kind of funky.
            # Note: can't seem to get disperser parameters from sasview
            # (1) Could create a sasview model that has not yet been
            # converted, assign the disperser to one of its polydisperse
            # parameters, then retrieve the disperser parameters from the
            # sasview model.
            # (2) Could write a disperser parameter retriever in sasview.
            # (3) Could modify sasview to use sasmodels.weights dispersers.
            # For now, rely on the fact that the sasview only ever uses
            # new dispersers in the set_dispersion call and create a new
            # one instead of trying to assign parameters.
            self.dispersion[parameter] = dispersion.get_pars()
        else:
            raise ValueError("%r is not a dispersity or orientation parameter"
                             % (parameter,))

    def _dispersion_mesh(self):
        # type: () -> List[Tuple[np.ndarray, np.ndarray]]
        """
        Create a mesh grid of dispersion parameters and weights.

        Returns [p1,p2,...],w where pj is a vector of values for parameter j
        and w is a vector containing the products for weights for each
        parameter set in the vector.
        """
        pars = [self._get_weights(p)
                for p in self._model_info.parameters.call_parameters
                if p.type == 'volume']
        return dispersion_mesh(self._model_info, pars)

    def _get_weights(self, par):
        # type: (Parameter) -> Tuple[np.ndarray, np.ndarray]
        """
        Return dispersion weights for parameter

        The result is a (values, weights) pair.  Parameters without
        polydispersity yield a single value with weight 1.0.
        """
        if par.name not in self.params:
            if par.name == self.multiplicity_info.control:
                return [self.multiplicity], [1.0]
            else:
                # For hidden parameters use the default value.
                value = self._model_info.parameters.defaults.get(par.name, np.nan)
                return [value], [1.0]
        elif par.polydisperse:
            dis = self.dispersion[par.name]
            if dis['type'] == 'array':
                value, weight = dis['values'], dis['weights']
            else:
                value, weight = weights.get_weights(
                    dis['type'], dis['npts'], dis['width'], dis['nsigmas'],
                    self.params[par.name], par.limits, par.relative_pd)
            # Normalize so that the weights sum to one.
            return value, weight / np.sum(weight)
        else:
            return [self.params[par.name]], [1.0]
[ce27e21]708
def test_model():
    # type: () -> float
    """
    Build the cylinder model and evaluate it at (0.1, 0.1).
    """
    model = _make_standard_model('cylinder')()
    return model.evalDistribution([0.1, 0.1])
[de97440]717
def test_structure_factor():
    # type: () -> None
    """
    Check that the hardsphere structure factor evaluates to a number.

    Raises *ValueError* if the value at (0.1, 0.1) is NaN.
    """
    hardsphere = _make_standard_model('hardsphere')()
    result = hardsphere.evalDistribution([0.1, 0.1])
    if np.isnan(result):
        raise ValueError("hardsphere returns null")
728
def test_rpa():
    # type: () -> float
    """
    Build the rpa model with multiplicity 3 and evaluate it at (0.1, 0.1).
    """
    model = _make_standard_model('rpa')(3)
    return model.evalDistribution([0.1, 0.1])
[04045f4]737
[4d76711]738
def test_model_list():
    # type: () -> None
    """
    Check that every listed model can be turned into a sasview model.

    A failure is re-raised after being annotated with the model name.
    """
    from .exception import annotate_exception
    for model_name in core.list_models():
        try:
            _make_standard_model(model_name)
        except:
            annotate_exception("when loading "+model_name)
            raise
751
def test_old_name():
    # type: () -> None
    """
    Import the cylinder model as sas.models.CylinderModel and run it.

    Does nothing when old style plugins are disabled or when sasview
    cannot be imported.
    """
    if not SUPPORT_OLD_STYLE_PLUGINS:
        return
    try:
        # if sasview is not on the path then don't try to test it
        import sas
    except ImportError:
        return
    # Loading the standard models also registers the old model names.
    load_standard_models()
    from sas.models.CylinderModel import CylinderModel
    cylinder = CylinderModel()
    cylinder.evalDistribution([0.1, 0.1])
767
if __name__ == "__main__":
    # Run the cylinder smoke test and show the value it returns.
    cylinder_value = test_model()
    print("cylinder(0.1,0.1)=%g"%cylinder_value)
Note: See TracBrowser for help on using the repository browser.